这里是一个使用查找表、每次查询 O(1) 的解决方案,并将它与 OP 的(第一个)解决方案以及 `numpy.searchsorted` 进行了比较。这个比较并不完全公平,因为 OP 的解决方案没有矢量化。无论如何,计时结果如下:
True # results equal
True # results equal
0.08163515606429428 # lookup
2.1996873939642683 # OP
0.016975965932942927 # numpy.searchsorted
对于这么小的列表,`searchsorted` 虽然是 O(log n),仍然会胜出。
代码:
import numpy as np
class find_next:
    """Find, for each query ``x``, the smallest element of ``a`` that is >= ``x``.

    Three interchangeable strategies are provided so they can be benchmarked
    against each other:

    * ``f_pp``   -- constant work per query via a precomputed bin lookup table;
    * ``lowest`` -- linear scan of the whole array, one scalar query at a time;
    * ``f_ss``   -- binary search with ``ndarray.searchsorted``.

    All three return values of ``a`` (not indices).
    """

    def __init__(self, a, max_bins=100000):
        """Sort ``a`` and build the lookup table used by ``f_pp``.

        Parameters
        ----------
        a : array_like
            Values to search.  Duplicates are discarded; since every method
            returns values rather than indices this does not affect results.
        max_bins : int, optional
            Number of equal-width bins covering ``[min(a), max(a)]``.

        Raises
        ------
        ValueError
            If ``a`` has fewer than two distinct values, or if two of them
            fall into the same bin (``f_pp`` would then skip one of them).
        """
        # np.unique sorts *and* drops duplicates.  A duplicate would make the
        # smallest gap zero, which used to disable the spacing check entirely.
        self.a = np.unique(a)
        if self.a.size < 2:
            # With a single value the span is zero and bin positions are undefined.
            raise ValueError('a must contain at least two distinct values')
        self.low = self.a[0]
        self.high = self.a[-1]
        self.span = self.high - self.low
        self.damin = np.diff(self.a).min()  # smallest gap between distinct values
        self.no_bins = max_bins
        # Bin of every element, computed with exactly the arithmetic f_pp uses
        # for queries.  That arithmetic is monotone, so table and queries always
        # agree, even for values sitting right on a bin edge (separately
        # generated linspace edges could disagree by one rounding step).
        bins = np.array(self._bin_pos(self.a), int)
        # f_pp is only correct if every bin holds at most one element.  Testing
        # the bins directly is exact, unlike comparing span // min_gap with
        # max_bins, whose floor division let slightly-too-wide bins through.
        if np.any(np.diff(bins) <= 0):
            raise ValueError('a too unevenly spaced for max_bins')
        # lut[j] = index of the first element whose bin is >= j.
        self.lut = np.searchsorted(bins, np.arange(max_bins + 1))

    def _bin_pos(self, v):
        """Return the fractional bin position of ``v`` (0 at min(a), no_bins at max(a))."""
        return (v - self.low) / self.span * self.no_bins

    def f_pp(self, x):
        """Return the smallest element of ``a`` that is >= ``x``, in O(1) per query.

        ``x`` may be a scalar or an array.  Queries below ``min(a)`` return
        ``min(a)``; a query above ``max(a)`` has no answer and raises
        ``IndexError``, as ``f_ss`` does.
        """
        # Clip the float position before the int cast: unclipped, queries well
        # below min(a) got a negative bin and could wrap around to the end of
        # the table, and huge queries could overflow the cast.
        i = np.array(np.clip(self._bin_pos(x), 0, self.no_bins), int)
        # lut[i] is the first candidate at or past x's bin.  Bins hold at most
        # one element, so if that candidate is still < x the answer is the first
        # element of a later bin, lut[i + 1]; past max(a) that index is out of range.
        return self.a[self.lut[i + (x > self.a[self.lut[i]])]]

    def lowest(self, x):
        """Return the smallest element of ``a`` >= scalar ``x`` by a full scan.

        Reference implementation (the OP's approach); raises ``IndexError``
        when ``x > max(a)``.
        """
        return self.a[self.a >= x][0]

    def f_ss(self, x):
        """Return the smallest element of ``a`` >= ``x`` via binary search.

        ``x`` may be a scalar or an array; raises ``IndexError`` when any
        query exceeds ``max(a)`` (searchsorted then returns ``len(a)``).
        """
        return self.a[self.a.searchsorted(x)]
# Demo data: nine values given out of order (find_next sorts them) and
# 10000 uniformly random queries drawn from the same range.
a = np.array([2, 1, 3, 4, 5, 6, 7, 8, 9])
x = np.random.uniform(1, 9, (10000,))
fn = find_next(a)

# Solve every query with each strategy.  sol_OP is a plain list, so comparing
# it with an array is element-wise; each print reports whether all answers
# agree with the reference scan (lookup table first, then searchsorted).
sol_pp = fn.f_pp(x)
sol_OP = [fn.lowest(xi) for xi in x]
sol_ss = fn.f_ss(x)
for sol in (sol_pp, sol_ss):
    print(np.all(sol_OP == sol))

from timeit import timeit

# Time each strategy; the statements are evaluated against this module's
# globals (fn, x), each one `number` times.
# NOTE(review): with 10000 queries and number=10000 the list-comprehension
# statement makes 10**8 calls to fn.lowest, each scanning the whole array,
# so it is likely to run far longer than the other two.
kwds = {'globals': globals(), 'number': 10000}
for stmt in ('fn.f_pp(x)', '[fn.lowest(xi) for xi in x]', 'fn.f_ss(x)'):
    print(timeit(stmt, **kwds))