numpy: cumulative multiplicity count

  • Emanuele Paolini · 7 years ago

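    I have a sorted int array that may contain repeated values. I want to count consecutive equal values, restarting from zero whenever a value differs from the previous one. This is the expected result, implemented with a plain Python loop: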

    import numpy as np
    
    def count_multiplicities(a):
        # r[i] = number of preceding consecutive elements equal to a[i]
        r = np.zeros(a.shape, dtype=a.dtype)
        for i in range(1, len(a)):
            if a[i] == a[i-1]:
                r[i] = r[i-1]+1   # same value: extend the current run
            else:
                r[i] = 0          # new value: restart from zero
        return r
    
    a = (np.random.rand(20)*5).astype(dtype=int)
    a.sort()
    
    print("given sorted array: ", a)
    print("multiplicity count: ", count_multiplicities(a))
    

    given sorted array:  [0 0 0 0 0 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4]
    multiplicity count:  [0 1 2 3 4 0 1 2 0 1 2 3 0 1 2 3 0 1 2 3]
    

    How can I get the same result efficiently with numpy? The array is long, but the repetitions are short (say, no more than ten).
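The loop can be restated without iteration: each count is the element's index minus the index at which its run of equal values begins. A minimal NumPy sketch of that identity, using `np.flatnonzero`, `np.diff` and `np.repeat` (the function name `count_multiplicities_repeat` is hypothetical; this is an alternative formulation, not one of the posted answers):

```python
import numpy as np

def count_multiplicities_repeat(a):
    """Hypothetical helper: count = index - start index of the element's run."""
    a = np.asarray(a)
    if a.size == 0:
        return np.zeros(0, dtype=np.intp)
    # Positions where a new run of equal values begins (position 0 always does)
    starts = np.flatnonzero(np.r_[True, a[1:] != a[:-1]])
    # Length of each run
    lengths = np.diff(np.r_[starts, a.size])
    # Spread each run's start index over the whole run and subtract it
    return np.arange(a.size) - np.repeat(starts, lengths)

a = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
print(count_multiplicities_repeat(a))  # [0 1 2 0 1 0 1 2 3]
```

Here `np.repeat` expands each run's start index across the run, which is the same bookkeeping the loop does incrementally with `r[i-1]+1`.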

  • Divakar · 7 years ago

    Here's a cumsum-based vectorized approach:

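    def count_multiplicities_cumsum_vectorized(a):
        # Step array: +1 everywhere, so its cumsum counts up within each run
        out = np.ones(a.size, dtype=int)
        if a.size == 0:
            return out
        # Indices where a new run starts (value differs from the previous one)
        idx = np.flatnonzero(a[1:] != a[:-1]) + 1
        out[0] = 0
        if idx.size:  # a single run has no change points
            # At each run start, step by 1 - (length of the run just ended) so the sum resets to 0
            out[idx[0]] = -idx[0] + 1
            out[idx[1:]] = idx[:-1] - idx[1:] + 1
        np.cumsum(out, out=out)
        return out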
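To see why the cumsum works, it helps to look at the step array just before `np.cumsum`: it is +1 everywhere except at position 0 (where it is 0) and at each run start, where it holds 1 minus the length of the run that just ended, so the running sum falls back to 0. A small sketch that returns that step array (the helper `multiplicity_steps` is illustrative, not part of the answer):

```python
import numpy as np

def multiplicity_steps(a):
    """Return the step array the cumsum approach builds before summing."""
    out = np.ones(a.size, dtype=int)
    idx = np.flatnonzero(a[1:] != a[:-1]) + 1  # run starts after position 0
    out[0] = 0
    if idx.size:
        out[idx[0]] = -idx[0] + 1
        out[idx[1:]] = idx[:-1] - idx[1:] + 1
    return out

a = np.array([0, 0, 0, 1, 1, 2])
steps = multiplicity_steps(a)
print(steps)             # [ 0  1  1 -2  1 -1]
print(np.cumsum(steps))  # [0 1 2 0 1 0]
```

The -2 at index 3 cancels the count of 2 reached at the end of the first run (length 3), and the -1 at index 5 cancels the count of 1 reached at the end of the second run (length 2).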
    

    In [58]: a
    Out[58]: array([0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])
    
    In [59]: count_multiplicities(a) # Original approach
    Out[59]: array([0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2])
    
    In [60]: count_multiplicities_cumsum_vectorized(a)
    Out[60]: array([0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2])
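The two outputs above agree on one input. A broader cross-check against the question's loop, over many random sorted arrays plus the edge case of a single run (where a version without an `idx.size` check raises `IndexError` on `idx[0]`), could be sketched as follows. Both functions are repeated so the block runs on its own, and the `idx.size` check is an assumption added for that edge case:

```python
import numpy as np

def count_multiplicities(a):
    # Reference: the question's loop (r starts at 0, so no else branch is needed)
    r = np.zeros(a.shape, dtype=a.dtype)
    for i in range(1, len(a)):
        if a[i] == a[i - 1]:
            r[i] = r[i - 1] + 1
    return r

def count_multiplicities_cumsum_vectorized(a):
    out = np.ones(a.size, dtype=int)
    if a.size == 0:
        return out
    idx = np.flatnonzero(a[1:] != a[:-1]) + 1
    out[0] = 0
    if idx.size:  # guard: a single run has no change points
        out[idx[0]] = -idx[0] + 1
        out[idx[1:]] = idx[:-1] - idx[1:] + 1
    np.cumsum(out, out=out)
    return out

rng = np.random.default_rng(0)
for n in [1, 2, 5, 50, 500]:
    for _ in range(20):
        a = np.sort(rng.integers(0, 4, size=n))
        assert np.array_equal(count_multiplicities(a),
                              count_multiplicities_cumsum_vectorized(a))
# A single run of equal values
assert count_multiplicities_cumsum_vectorized(np.full(4, 7)).tolist() == [0, 1, 2, 3]
print("all checks passed")
```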
    

    In [66]: a = (np.random.rand(200000)*1000).astype(dtype=int)
        ...: a.sort()
        ...: 
    
    In [67]: a
    Out[67]: array([  0,   0,   0, ..., 999, 999, 999])
    
    In [68]: %timeit count_multiplicities(a)
    10 loops, best of 3: 87.2 ms per loop
    
    In [69]: %timeit count_multiplicities_cumsum_vectorized(a)
    1000 loops, best of 3: 739 µs per loop
    

    Related post.

  • max9111 · 7 years ago

    I would use Numba for problems like this:

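    import numba
    import numpy as np
    # Explicit signature: accepts int32 only, so cast (astype(int) is usually int64)
    nb_count_multiplicities = numba.njit("int32[:](int32[:])")(count_multiplicities)
    X = nb_count_multiplicities(a.astype(np.int32))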
    

    Without rewriting the code at all, it is about 50% faster than Divakar's vectorized solution.
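The signature string `"int32[:](int32[:])"` pins the compiled function to `int32` arrays, while `astype(int)` typically produces `int64` on 64-bit Linux and macOS, in which case the call finds no matching definition. A minimal sketch that lets Numba compile lazily for whatever integer dtype it receives; it assumes Numba is installed and, only so the block also runs without it, falls back to the plain-Python loop otherwise (the name `count_multiplicities_nb` is illustrative):

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for portability: without Numba, run the same loop as plain Python
    def njit(func):
        return func

@njit
def count_multiplicities_nb(a):
    # Same loop as the question; with no signature, Numba compiles on the
    # first call for the dtype it actually sees (int32, int64, ...)
    r = np.zeros_like(a)
    for i in range(1, a.shape[0]):
        if a[i] == a[i - 1]:
            r[i] = r[i - 1] + 1
    return r

a = np.sort((np.random.rand(20) * 5).astype(int))
print(a)
print(count_multiplicities_nb(a))
```

The first call includes compilation time, so a fair timing comparison should call the function once before measuring.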