代码之家  ›  专栏  ›  技术社区  ›  galah92

4矩阵乘法的np.einsum性能

  •  1
  • galah92  · 技术社区  · 7 年前

    给出以下3个矩阵:

    M = np.arange(35 * 37 * 59).reshape([35, 37, 59])
    A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
    B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
    C = np.arange(59 * 27).reshape([59, 27])
    

    我在用 einsum 计算:

    D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=True);
    

    但我发现那时候的表现要差得多:

    tmp = np.einsum('xyf,xtf->tfy', A, M, optimize=True)
    tmp = np.einsum('ytpf,yft->ftp', B, tmp, optimize=True)
    D2 = np.einsum('fr,ftp->tpr', C, tmp, optimize=True)
    

    我不明白为什么。
    总的来说,我正在尽可能优化这段代码我读过关于 np.tensordot 函数,但我似乎不知道如何利用它进行给定的计算。

    1 回复  |  直到 7 年前
        1
  •  3
  •   Daniel    7 年前

    你好像偶然发现一个案子 greedy 路径给出了一个非最优缩放。

    >>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy");
    >>> print(desc)
      Complete contraction:  xyf,xtf,ytpf,fr->tpr
             Naive scaling:  6
         Optimized scaling:  5
          Naive FLOP count:  3.219e+10
      Optimized FLOP count:  4.165e+08
       Theoretical speedup:  77.299
      Largest intermediate:  5.371e+06 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5              ytpf,xyf->xptf                         xtf,fr,xptf->tpr
       4               xptf,xtf->ptf                              fr,ptf->tpr
       4                 ptf,fr->tpr                                 tpr->tpr
    
    >>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal");
    >>> print(desc)
      Complete contraction:  xyf,xtf,ytpf,fr->tpr
             Naive scaling:  6
         Optimized scaling:  4
          Naive FLOP count:  3.219e+10
      Optimized FLOP count:  2.744e+07
       Theoretical speedup:  1173.425
      Largest intermediate:  1.535e+05 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       4                xtf,xyf->ytf                         ytpf,fr,ytf->tpr
       4               ytf,ytpf->ptf                              fr,ptf->tpr
       4                 ptf,fr->tpr                                 tpr->tpr
    

    使用 np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal") 你应该以最好的成绩跑我可以看看这个边缘,看看贪婪是否能抓住它。

    推荐文章