代码之家  ›  专栏  ›  技术社区  ›  Paul Terwilliger

优化:从大于(或等于)`x的数组中返回最小值`

  •  0
  • Paul Terwilliger  · 技术社区  · 7 年前

    编辑:我的问题与建议的重复不同,因为我已经有了一种实现方法 lowest . 我的问题不是如何实施 最低的 ,而是如何优化 最低的 跑得更快。

    假设我有一个数组 a . 例如:

    import numpy as np
    a = np.array([2, 1, 3, 4, 5, 6, 7, 8, 9])
    

    假设我有一个浮点数 x . 例如:

    x = 6.5
    

    我想返回中的最小值 也大于或等于 x个 . 所以在这种情况下。。。

    print lowest(a, x)
    >>> 7
    

    我尝试了许多函数来代替 最低的 . 例如:

    def lowest(a, x):
    """ `a` should be a sorted numpy array"""
        return lowest[lowest >= x][0]
    
    def lowest(a, x):
    """ `a` should be a sorted `list`, not a numpy array"""
        k = sorted(a + [x])
        return k[k.index(x) + 1]
    

    但是,功能 最低的 仍然是我代码的瓶颈,约占90%。

    是否有更快的方法实现该功能 最低的 ?

    关于我的代码的一些规则:

    • 可以假设长度为10
    • 功能 最低的 至少运行10万次。这可能是一个设计问题,但我感兴趣的是是否有更快的实现我的问题。
    • 可以在运行这些循环之前进行预处理。 x个 将有所不同,但 不会。
    • 可以假设 a[0] <= x <= a[-1] 始终是 True
    1 回复  |  直到 7 年前
        1
  •  2
  •   Paul Panzer    7 年前

    这里是一个使用查找表的O(1)解决方案,与OP的(第一个)解决方案相比 numpy.searchsorted . 这不是百分之百的公平,因为OP的解决方案没有矢量化。无论如何,时间安排:

    True                  # results equal
    True                  # results equal
    0.08163515606429428   # lookup
    2.1996873939642683    # OP
    0.016975965932942927  # numpy.searchsorted
    

    对于这个小列表大小 seachsorted 即使是O(log n),也会获胜。

    代码:

    import numpy as np
    
    class find_next:
        def __init__(self, a, max_bins=100000):
            self.a = np.sort(a)
            self.low = self.a[0]
            self.high = self.a[-1]
            self.span = self.a[-1] - self.a[0]
            self.damin = np.diff(self.a).min()
            if self.span // self.damin > max_bins:
                raise ValueError('a too unevenly spaced for max_bins')
            self.lut = np.searchsorted(self.a, np.linspace(self.low, self.high,
                                                           max_bins + 1))
            self.no_bins = max_bins
        def f_pp(self, x):
            i = np.array((x-self.low)/self.span * self.no_bins, int)
            return self.a[self.lut[i + (x > self.a[self.lut[i]])]]
        def lowest(self, x):
            return self.a[self.a >= x][0]
        def f_ss(self, x):
            return self.a[self.a.searchsorted(x)]
    
    a = np.array([2, 1, 3, 4, 5, 6, 7, 8, 9])
    
    x = np.random.uniform(1, 9, (10000,))
    
    fn = find_next(a)
    sol_pp = fn.f_pp(x)
    sol_OP = [fn.lowest(xi) for xi in x]
    sol_ss = fn.f_ss(x)
    
    print(np.all(sol_OP == sol_pp))
    print(np.all(sol_OP == sol_ss))
    
    from timeit import timeit
    kwds = dict(globals=globals(), number=10000)
    
    print(timeit('fn.f_pp(x)', **kwds))
    print(timeit('[fn.lowest(xi) for xi in x]', **kwds))
    print(timeit('fn.f_ss(x)', **kwds))