好的,让我们从使代码更通用开始,稍后我将讨论性能方面。
我一般不使用:
import numpy as np
cimport numpy as np
我个人喜欢用不同的名字
cimport
因为它有助于保持C端和麻木的Python端分开。所以这个答案我会用
import numpy as np
cimport numpy as cnp
我也会做
lower_limit
和
upper_limit
函数的参数。在您的案例中,这些可能是静态(或全局)定义的,但它使示例更加独立。所以起点是代码的一个稍微修改过的版本:
cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
赛通的一个很好的特点是
fused types
,因此您可以很容易地将此函数泛化为不同的类型。您的方法只适用于32位整数数组(至少如果
int
在您的计算机上是32位)。很容易支持更多的数组类型:
ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t
cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
当然,如果需要,可以添加更多类型。其优点是新版本在旧版本失败的情况下工作:
>>> search_1(np.arange(100, dtype=np.float_), 10, 20)
ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
>>> search_2(np.arange(100, dtype=np.float_), 10, 20)
19.0
现在更一般了,让我们来看看你的函数实际上做了什么:
-
创建一个布尔数组,其中元素高于下限
-
创建一个布尔数组,其中元素低于上限
-
通过两个布尔数组中的位和创建布尔数组。
-
创建一个只包含布尔值掩码为真的元素的新数组
-
只能从最后一个数组中提取一个元素
为什么要创建这么多数组?我的意思是,你可以简单地计算在限制范围内有多少个元素,取一个介于0和限制范围内元素数量之间的随机整数,然后取任何元素。
会是
在结果数组的那个索引处。
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
# Count the number of elements that are within the limits
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1
# Take a random index
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
# Go through the array again and take the element at the random index that
# is within the bounds
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1
它不会快很多,但会节省很多内存。因为你没有中间数组,你根本不需要内存视图,但是如果你愿意,你可以替换
cnp.ndarray[int_or_float] arr
在参数列表中
int_or_float[:]
甚至
int_or_float[::1] arr
并对memoryview进行操作(可能不会更快,但也不会更慢)。
我通常喜欢numba而不是cython(至少如果我只是在使用它),所以让我们将其与该代码的numba版本进行比较:
import numba as nb
import numpy as np
@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1
random_index = np.random.randint(0, num_valids)
valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1
以及
numexpr
变体:
import numexpr
np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
好吧,让我们做一个基准:
from simple_benchmark import benchmark, MultiArgument
arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
b = benchmark(funcs, arguments, argument_name='array size')
因此,如果不使用中间数组,您可以快5倍左右,如果您要使用numba,您可以得到另一个因子5(似乎我在那里缺少了一些可能的cython优化,numba通常快2倍或与cython一样快)。因此,使用Numba解决方案,您可以更快地获得20倍。
数字表达式
在这里并不具有可比性,主要是因为您不能在那里使用布尔数组索引。
差异将取决于数组的内容和限制。您还必须测量应用程序的性能。
作为旁白:如果下限和上限一般不改变,最快的解决方案是过滤数组一次,然后调用
np.random.choice
好几次。那可能是
数量级更快
.
lower_limit = ...
upper_limit = ...
filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]
def search_cached():
return np.random.choice(filtered_array)
%timeit search_cached()
2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
所以速度快了近1000倍,根本不需要赛通或者麻木。但这是一个特殊的案例,可能对你没有帮助。
如果您想自己动手,基准设置就在这里(基于Jupyter笔记本/实验室,因此
%
-符号):
%load_ext cython
%%cython
cimport numpy as cnp
import numpy as np
cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t
cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1
import numexpr
import numba as nb
import numpy as np
def search_numexpr(arr, l, u):
return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1
random_index = np.random.randint(0, num_valids)
valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1
from simple_benchmark import benchmark, MultiArgument
arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
b = benchmark(funcs, arguments, argument_name='array size')
%matplotlib widget
import matplotlib.pyplot as plt
plt.style.use('ggplot')
b.plot()