代码之家  ›  专栏  ›  技术社区  ›  Waiski

Cython-高效过滤类型化的内存视图

  •  1
  • Waiski  · 技术社区  · 6 年前

    此cython函数返回numpy数组元素中某些限制内的随机元素:

    cdef int search(np.ndarray[int] pool):
      cdef np.ndarray[int] limited
      limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
      return np.random.choice(limited)
    

    这个很好用。但是,这个函数对代码的性能非常关键。类型化的memoryview显然比numpy数组快得多,但是它们不能像上面那样被过滤。

    如何使用类型化的memoryview编写一个与上面相同的函数?还是有其他方法来提高函数的性能?

    1 回复  |  直到 6 年前
        1
  •  5
  •   MSeifert    6 年前

    好的,让我们从使代码更通用开始,稍后我将讨论性能方面。

    我一般不使用:

    import numpy as np
    cimport numpy as np
    

    我个人喜欢用不同的名字 cimport 因为它有助于保持C端和麻木的Python端分开。所以这个答案我会用

    import numpy as np
    cimport numpy as cnp
    

    我也会做 lower_limit upper_limit 函数的参数。在您的案例中,这些可能是静态(或全局)定义的,但它使示例更加独立。所以起点是代码的一个稍微修改过的版本:

    cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
        cdef cnp.ndarray[int] limited
        limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
        return np.random.choice(limited)
    

    赛通的一个很好的特点是 fused types ,因此您可以很容易地将此函数泛化为不同的类型。您的方法只适用于32位整数数组(至少如果 int 在您的计算机上是32位)。很容易支持更多的数组类型:

    ctypedef fused int_or_float:
        cnp.int32_t
        cnp.int64_t
        cnp.float32_t
        cnp.float64_t
    
    cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
        cdef cnp.ndarray[int_or_float] limited
        limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
        return np.random.choice(limited)
    

    当然,如果需要,可以添加更多类型。其优点是新版本在旧版本失败的情况下工作:

    >>> search_1(np.arange(100, dtype=np.float_), 10, 20)
    ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
    >>> search_2(np.arange(100, dtype=np.float_), 10, 20)
    19.0
    

    现在更一般了,让我们来看看你的函数实际上做了什么:

    • 创建一个布尔数组,其中元素高于下限
    • 创建一个布尔数组,其中元素低于上限
    • 通过两个布尔数组中的位和创建布尔数组。
    • 创建一个只包含布尔值掩码为真的元素的新数组
    • 只能从最后一个数组中提取一个元素

    为什么要创建这么多数组?我的意思是,你可以简单地计算在限制范围内有多少个元素,取一个介于0和限制范围内元素数量之间的随机整数,然后取任何元素。 会是 在结果数组的那个索引处。

    cimport cython
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
        cdef int_or_float element
    
        # Count the number of elements that are within the limits
        cdef Py_ssize_t num_valid = 0
        for index in range(arr.shape[0]):
            element = arr[index]
            if lower_bound <= element <= upper_bound:
                num_valid += 1
    
        # Take a random index
        cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
    
        # Go through the array again and take the element at the random index that
        # is within the bounds
        cdef Py_ssize_t clamped_index = 0
        for index in range(arr.shape[0]):
            element = arr[index]
            if lower_bound <= element <= upper_bound:
                if clamped_index == random_index:
                    return element
                clamped_index += 1
    

    它不会快很多,但会节省很多内存。因为你没有中间数组,你根本不需要内存视图,但是如果你愿意,你可以替换 cnp.ndarray[int_or_float] arr 在参数列表中 int_or_float[:] 甚至 int_or_float[::1] arr 并对memoryview进行操作(可能不会更快,但也不会更慢)。

    我通常喜欢numba而不是cython(至少如果我只是在使用它),所以让我们将其与该代码的numba版本进行比较:

    import numba as nb
    import numpy as np
    
    @nb.njit
    def search_numba(arr, lower, upper):
        num_valids = 0
        for item in arr:
            if item >= lower and item <= upper:
                num_valids += 1
    
        random_index = np.random.randint(0, num_valids)
    
        valid_index = 0
        for item in arr:
            if item >= lower and item <= upper:
                if valid_index == random_index:
                    return item
                valid_index += 1
    

    以及 numexpr 变体:

    import numexpr
    
    np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
    

    好吧,让我们做一个基准:

    from simple_benchmark import benchmark, MultiArgument
    
    arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
    funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
    
    b = benchmark(funcs, arguments, argument_name='array size')
    

    enter image description here

    因此,如果不使用中间数组,您可以快5倍左右,如果您要使用numba,您可以得到另一个因子5(似乎我在那里缺少了一些可能的cython优化,numba通常快2倍或与cython一样快)。因此,使用Numba解决方案,您可以更快地获得20倍。

    数字表达式 在这里并不具有可比性,主要是因为您不能在那里使用布尔数组索引。

    差异将取决于数组的内容和限制。您还必须测量应用程序的性能。


    作为旁白:如果下限和上限一般不改变,最快的解决方案是过滤数组一次,然后调用 np.random.choice 好几次。那可能是 数量级更快 .

    lower_limit = ...
    upper_limit = ...
    filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    
    def search_cached():
        return np.random.choice(filtered_array)
    
    %timeit search_cached()
    2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    

    所以速度快了近1000倍,根本不需要赛通或者麻木。但这是一个特殊的案例,可能对你没有帮助。


    如果您想自己动手,基准设置就在这里(基于Jupyter笔记本/实验室,因此 % -符号):

    %load_ext cython
    
    %%cython
    
    cimport numpy as cnp
    import numpy as np
    
    cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
        cdef cnp.ndarray[int] limited
        limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
        return np.random.choice(limited)
    
    ctypedef fused int_or_float:
        cnp.int32_t
        cnp.int64_t
        cnp.float32_t
        cnp.float64_t
    
    cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
        cdef cnp.ndarray[int_or_float] limited
        limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
        return np.random.choice(limited)
    
    cimport cython
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
        cdef int_or_float element
        cdef Py_ssize_t num_valid = 0
        for index in range(arr.shape[0]):
            element = arr[index]
            if lower_bound <= element <= upper_bound:
                num_valid += 1
    
        cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
    
        cdef Py_ssize_t clamped_index = 0
        for index in range(arr.shape[0]):
            element = arr[index]
            if lower_bound <= element <= upper_bound:
                if clamped_index == random_index:
                    return element
                clamped_index += 1
    
    import numexpr
    import numba as nb
    import numpy as np
    
    def search_numexpr(arr, l, u):
        return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
    
    @nb.njit
    def search_numba(arr, lower, upper):
        num_valids = 0
        for item in arr:
            if item >= lower and item <= upper:
                num_valids += 1
    
        random_index = np.random.randint(0, num_valids)
    
        valid_index = 0
        for item in arr:
            if item >= lower and item <= upper:
                if valid_index == random_index:
                    return item
                valid_index += 1
    
    from simple_benchmark import benchmark, MultiArgument
    
    arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
    funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
    
    b = benchmark(funcs, arguments, argument_name='array size')
    
    %matplotlib widget
    
    import matplotlib.pyplot as plt
    
    plt.style.use('ggplot')
    b.plot()
    
    推荐文章