代码之家  ›  专栏  ›  技术社区  ›  splinter

当函数包含条件时,使用Numpy将函数应用于数组

  •  0
  • splinter  · 技术社区  · 6 年前

    当函数包含条件时,我很难将函数应用于数组。我有一个低效的解决方法,正在寻找一个高效(快速)的方法。在一个简单的例子中:

    pts = np.linspace(0,1,11)
    def fun(x, y):
        if x > y:
            return 0
        else:
            return 1
    

    现在,如果我跑步:

    result = fun(pts, pts)
    

    ValueError:包含多个元素的数组的真值不明确。使用a.any()或a.all()

    在家长大 if x > y 线路。我的低效变通方法给出了正确的结果,但速度太慢:

    result = np.full([len(pts)]*2, np.nan)
    for i in range(len(pts)):
        for j in range(len(pts)):
            result[i,j] = fun(pts[i], pts[j])
    

    以更好(更重要的是,更快)的方式实现这一目标的最佳方式是什么?

    当函数包含条件时,我很难将函数应用于数组。我有一个低效的解决方法,正在寻找一个高效(快速)的方法。在一个简单的例子中:

    pts=np.linspace(0,1,11)
    def fun(x,y):
    如果x>y:
    返回0
    

    现在,如果我跑步:

    结果=乐趣(分数,分数)
    

    然后我得到了错误

    在家长大 如果x>Y

    结果=np.full([len(pts)]*2,np.nan)
    对于范围内的i(len(pts)):
    

    编辑 :使用

    def fun(x, y):
        if x > y:
            return 0
        else:
            return 1
    x = np.array(range(10))
    y = np.array(range(10))
    xv,yv = np.meshgrid(x,y)
    result = fun(xv, yv)  
    

    ValueError .

    3 回复  |  直到 6 年前
        1
  •  1
  •   kabanus    6 年前

    这个错误非常明显——假设您有

    x = np.array([1,2])
    y = np.array([2,1])
    

    以致

    (x>y) == np.array([0,1])
    

    你的研究结果应该是什么 if np.array([0,1]) 陈述是真是假? numpy

    (x>y).all()
    

    (x>y).any()
    

    是明确的,因此 努比 为您提供解决方案-任何细胞对满足条件,或所有细胞对满足条件-都是明确的真值。你必须为自己准确地定义你的意思 向量x大于向量y .

    这个 x y x[i]>y[j] 要使用网格网格生成所有对,请执行以下操作:

    >>> import numpy as np
    >>> x=np.array(range(10))
    >>> y=np.array(range(10))
    >>> xv,yv=np.meshgrid(x,y)
    >>> xv[xv>yv]
    array([1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8,
           9, 4, 5, 6, 7, 8, 9, 5, 6, 7, 8, 9, 6, 7, 8, 9, 7, 8, 9, 8, 9, 9])
    >>> yv[xv>yv]
    array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
           2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8])
    

    要么发送 xv yv fun xi,yj 以致 xi>yj . 如果需要实际的索引,只需返回 xv>yv ij 对应 x[i] y[j] . 就你而言:

    def fun(x, y):
        xv,yv=np.meshgrid(x,y)
        return xv>yv
    

    fun(x,y)[i][j] 如果 x[i]>y[j] ,否则为假。或者

    return  np.where(xv>yv)
    

    将返回两个索引对数组的元组,这样

    for i,j in fun(x,y):
    

    保证 x[i]>y[j]

        2
  •  1
  •   hpaulj    6 年前
    In [253]: x = np.random.randint(0,10,5)
    In [254]: y = np.random.randint(0,10,5)
    In [255]: x
    Out[255]: array([3, 2, 2, 2, 5])
    In [256]: y
    Out[256]: array([2, 6, 7, 6, 5])
    In [257]: x>y
    Out[257]: array([ True, False, False, False, False])
    In [258]: np.where(x>y,0,1)
    Out[258]: array([0, 1, 1, 1, 1])
    

    要与这两个一维阵列进行笛卡尔比较,请重塑其中一个阵列,以便它可以使用 broadcasting

    In [259]: x[:,None]>y
    Out[259]: 
    array([[ True, False, False, False, False],
           [False, False, False, False, False],
           [False, False, False, False, False],
           [False, False, False, False, False],
           [ True, False, False, False, False]])
    In [260]: np.where(x[:,None]>y,0,1)
    Out[260]: 
    array([[0, 1, 1, 1, 1],
           [1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1],
           [0, 1, 1, 1, 1]])
    

    if a>b 生成布尔数组,该数组不能用于 如果 陈述迭代之所以有效,是因为它传递标量值。对于一些复杂的函数,这是您所能做的最好的( np.vectorize

    where 将布尔数组映射到所需的1/0上的工作做得很好。还有其他方法可以实现这种映射。

    None

        3
  •  1
  •   max9111    6 年前

    对于一个更复杂的示例,或者如果处理的数组有点大,或者可以写入已经分配的数组,则可以考虑。 Numba

    例子

    import numba as nb
    import numpy as np
    
    @nb.njit()
    def fun(x, y):
      if x > y:
        return 0
      else:
        return 1
    
    @nb.njit(parallel=False)
    #@nb.njit(parallel=True)
    def loop(x,y):
      result=np.empty((x.shape[0],y.shape[0]),dtype=np.int32)
      for i in nb.prange(x.shape[0]):
        for j in range(y.shape[0]):
          result[i,j] = fun(x[i], y[j])
      return result
    
    @nb.njit(parallel=False)
    def loop_preallocated(x,y,result):
      for i in nb.prange(x.shape[0]):
        for j in range(y.shape[0]):
          result[i,j] = fun(x[i], y[j])
      return result
    

    时间安排

    x = np.array(range(1000))
    y = np.array(range(1000))
    
    #Compilation overhead of the first call is neglected
    
    res=np.where(x[:,None]>y,0,1) -> 2.46ms
    loop(single_threaded)         -> 1.23ms
    loop(parallel)                -> 1.0ms
    loop(single_threaded)*        -> 0.27ms
    loop(parallel)*               -> 0.058ms
    

    *可能受缓存的影响。测试你自己的例子。