代码之家  ›  专栏  ›  技术社区  ›  b-fg

二维元素的Numpy矩阵乘法

  •  1
  • b-fg  · 技术社区  · 6 年前

    我有一个 a 努比 ndarray

    a =  ([[ uu, uv, uw],
           [ uv, vv, vw],
           [ uw, vw, ww]])
    

    每个组件本身就是一个大小为的二维数组 (N,M) ,所以 矩阵有一个 (3,3,N,M)

    我怎样才能执行 a*a 以蟒蛇的方式? a@a 抛出以下错误(对于N=1218和M=540):

    (尺寸3)!=1218(尺寸2)

    我希望能够执行这个操作,就像 其中只有标量值 不会抛出与其形状相关的错误,因为它是一个简单的3x3矩阵乘法。

    谢谢。

    1 回复  |  直到 6 年前
        1
  •  1
  •   Divakar    6 年前

    假设您希望沿最后两个轴对每个元素执行矩阵乘法,我们可以使用 np.einsum -

    np.einsum('ijkl,jmkl->imkl',a,a)
    

    样品运行验证-

    In [43]: np.random.seed(0)
    
    In [44]: a = np.random.rand(3,3,4,5)
    
    In [45]: a[:,:,0,0].dot(a[:,:,0,0])
    Out[45]: 
    array([[0.71750146, 1.17057872, 1.11135764],
           [0.62938365, 0.86437796, 0.74541383],
           [1.04636618, 1.62011127, 1.35483565]])
    
    In [46]: np.einsum('ijkl,jmkl->imkl',a,a)[:,:,0,0]
    Out[46]: 
    array([[0.71750146, 1.17057872, 1.11135764],
           [0.62938365, 0.86437796, 0.74541383],
           [1.04636618, 1.62011127, 1.35483565]])