代码之家  ›  专栏  ›  技术社区  ›  Bill

有没有办法让python all()函数处理多维数组?

  •  0
  • Bill  · 技术社区  · 6 年前

    我正在尝试实现一个通用的和灵活的 __eq__ 方法,该方法用于尽可能多的对象类型,包括iterables和numpy数组。

    以下是我目前为止的情况:

    class Environment:
    
        def __init__(self, state):
            self.state = state
    
        def __eq__(self, other):
            """Compare two environments based on their states.
            """
            if isinstance(other, self.__class__):
                try:
                    return all(self.state == other.state)
                except TypeError:
                    return self.state == other.state
            return False
    

    这对大多数对象类型(包括一维数组)都适用:

    s = np.array(range(6))
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    
    s = 'abcdef'
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    

    问题是,当 self.state 是一个多维的numpy数组。

    s = np.array(range(6)).reshape((2, 3))
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2
    

    生产:

    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    

    很明显,我可以查一下 isinstance(other, np.ndarray) 然后做 (return self.state == other.state).all() 但我只是想有一种更通用的方法可以用一条语句处理所有的iterables、集合和任何类型的数组。

    我也有点困惑为什么 all() 不会像这样迭代数组的所有元素 array.all() . 有办法触发吗 np.nditer 可能会这样做?

    1 回复  |  直到 6 年前
        1
  •  0
  •   hpaulj    6 年前
    Signature: all(iterable, /)
    Docstring:
    Return True if bool(x) is True for all values x in the iterable.
    

    对于一维数组:

    In [200]: x=np.ones(3)                                                               
    In [201]: x                                                                          
    Out[201]: array([1., 1., 1.])
    In [202]: y = x==x                                                                   
    In [203]: y          # 1d array of booleans                                                                      
    Out[203]: array([ True,  True,  True])
    In [204]: bool(y[0])                                                                 
    Out[204]: True
    In [205]: all(y)                                                                     
    Out[205]: True
    

    对于二维数组:

    In [206]: x=np.ones((2,3))                                                           
    In [207]: x                                                                          
    Out[207]: 
    array([[1., 1., 1.],
           [1., 1., 1.]])
    In [208]: y = x==x                                                                   
    In [209]: y                                                                          
    Out[209]: 
    array([[ True,  True,  True],
           [ True,  True,  True]])
    In [210]: y[0]                                                                       
    Out[210]: array([ True,  True,  True])
    In [211]: bool(y[0])                                                                 
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-211-d0ce0868392c> in <module>
    ----> 1 bool(y[0])
    
    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    

    但对于不同的二维阵列:

    In [212]: x=np.ones((3,1))                                                           
    In [213]: y = x==x                                                                   
    In [214]: y                                                                          
    Out[214]: 
    array([[ True],
           [ True],
           [ True]])
    In [215]: y[0]                                                                       
    Out[215]: array([ True])
    In [216]: bool(y[0])                                                                 
    Out[216]: True
    In [217]: all(y)                                                                     
    Out[217]: True
    

    对numpy数组的迭代发生在第一个维度上。 [i for i in x]

    每当在需要标量布尔值的上下文中使用多值布尔值数组时,都会引发此歧义值错误。 if or/and 表达式是常见的。

    In [223]: x=np.ones((2,3))                                                           
    In [224]: y = x==x                                                                   
    In [225]: np.all(y)                                                                  
    Out[225]: True
    

    np.all 是不同的巨蟒 all 因为它“知道”尺寸。在这种情况下,它会 ravel 要将数组视为1d:

    缺省值(默认值) axis = None )是对输入数组的所有维度执行逻辑和。