
Matching Numba's performance with Cython

  •  4
  • ead  · Tech Community  · 6 years ago

    Usually I can match Numba's performance when using Cython. In this example, however, I could not: Numba is about 4 times faster than my Cython version.

    Here is the Cython version:

    %%cython -c=-march=native -c=-O3
    cimport numpy as np
    import numpy as np
    cimport cython
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    def cy_where(double[::1] df):
        cdef int i
        cdef int n = len(df)
        cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
        for i in range(n):
            if df[i]>0.5:
                output[i] = 2.0*df[i]
            else:
                output[i] = df[i]
        return output 
    

    Here is the Numba version:

    import numba as nb
    @nb.njit
    def nb_where(df):
        n = len(df)
        output = np.empty(n, dtype=np.float64)
        for i in range(n):
            if df[i]>0.5:
                output[i] = 2.0*df[i]
            else:
                output[i] = df[i]
        return output
    

    When tested, the Cython version is on par with NumPy's `where`, but clearly slower than Numba:

    #Python3.6 + Cython 0.28.3 + gcc-7.2
    import numpy as np
    np.random.seed(0)
    n = 10000000
    data = np.random.random(n)
    
    assert (cy_where(data)==nb_where(data)).all()
    assert (np.where(data>0.5,2*data, data)==nb_where(data)).all()
    
    %timeit cy_where(data)       # 179ms
    %timeit nb_where(data)       # 49ms (!!)
    %timeit np.where(data>0.5,2*data, data)  # 278 ms
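    The `%timeit` lines above are IPython magics. Outside IPython, a rough equivalent can be sketched with the standard-library `timeit` module. The helper name `best_time_ms`, the wrapper `np_where`, and the smaller array size are assumptions made for this illustration and are not part of the original post; the numbers it prints will differ from the ones quoted above.

```python
import timeit

import numpy as np


def np_where(df):
    # Same computation as the loops above, expressed with NumPy
    return np.where(df > 0.5, 2 * df, df)


def best_time_ms(func, arg, repeat=5, number=3):
    """Best per-call time in milliseconds over `repeat` rounds of `number` calls."""
    timings = timeit.repeat(lambda: func(arg), repeat=repeat, number=number)
    return min(timings) / number * 1e3


np.random.seed(0)
sample = np.random.random(100000)  # smaller than the post's n, so it runs quickly

elapsed_ms = best_time_ms(np_where, sample)
print(f"np.where: {elapsed_ms:.3f} ms per call")
```

    The same helper could be pointed at `cy_where` or `nb_where` once they are compiled; for Numba, calling the function once beforehand keeps JIT compilation out of the measurement.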
    

    What is the reason for Numba's performance, and how can it be matched when using Cython?


    As suggested by @max9111, I eliminated the stride by using a contiguous memory view, but this does not improve performance significantly:

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def cy_where_cont(double[::1] df):
        cdef int i
        cdef int n = len(df)
        cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
        cdef double[::1] view = output  # view as contiguous!
        for i in range(n):
            if df[i]>0.5:
                view[i] = 2.0*df[i]
            else:
                view[i] = df[i]
        return output 
    
    %timeit cy_where_cont(data)   #  165 ms
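    For readers unfamiliar with the term, the difference between a contiguous and a strided buffer can be inspected from plain NumPy. This is only an illustrative sketch; the array names `full`, `every_other`, and `packed` are made up for the example and do not appear in the original post.

```python
import numpy as np

full = np.arange(10, dtype=np.float64)
every_other = full[::2]  # a strided view: consecutive elements are 16 bytes apart

print(full.flags.c_contiguous)         # True
print(every_other.flags.c_contiguous)  # False
print(every_other.strides)             # (16,)

# A `double[::1]` argument only accepts C-contiguous buffers, so a strided
# view like this one has to be copied into contiguous memory first.
packed = np.ascontiguousarray(every_other)
print(packed.flags.c_contiguous, packed.strides)  # True (8,)
```

    With a `double[::1]` declaration the compiler may assume a unit stride, whereas a buffer declared as `np.ndarray[dtype=double]` without a contiguity mode is indexed through its runtime stride.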
    
    2 Answers  |  as of 6 years ago
        1
  •  2
  •   chrisb    6 years ago

    $ CC=clang ipython
    <... setup code>
    
    In [7]: %timeit cy_where(data)       # 179ms
       ...: %timeit nb_where(data)       # 49ms (!!) 
    
    30.8 ms ± 309 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    30.2 ms ± 498 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
        2
  •  0
  •   serge-sans-paille    6 years ago

    pythran

    import numpy as np
    #pythran export work(float64[])
    
    def work(df):
        return np.where(df>0.5,2*df, df)
    

    CXX=clang++ CC=clang pythran pythran_work.py -O3 -march=native
    

    import numpy as np
    np.random.seed(0)
    n = 10000000
    data = np.random.random(n)
    import numba_work, pythran_work
    
    %timeit numba_work.work(data)
    12.7 ms ± 20 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    %timeit pythran_work.work(data)
    12.7 ms ± 32.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    