代码之家  ›  专栏  ›  技术社区  ›  Baba Dan Constantin

SSE4.1在矩阵4x4乘法上比SSE3慢?

  •  1
  • Baba Dan Constantin  · 技术社区  · 4 月前

    所以我有一个矩阵乘法的SSE3实现:

    /**
     * 4x4 matrix multiplication using SSE3: result = affector * affected.
     *
     * All three pointers must reference 16-float arrays (row-major 4x4
     * matrices) that are 16-byte aligned — _mm_load_ps faults on unaligned
     * addresses. Computes result[4*i + j] = dot(row i of affector,
     * column j of affected).
     *
     * Loop is unrolled for performance.
     * @attention As opposed to non-SIMD multiplication we're using column-major
     *            access for the right-hand matrix: `affected` is transposed in
     *            registers below so each product is a row·column dot product.
     */
    inline void multiply(const float *__restrict affector, const float *__restrict affected, float *__restrict result)
    {
      // std::cout << "INSIDE SS3" << std::endl;
      // Load the four rows of the left-hand matrix (affector)
      __m128 a0 = _mm_load_ps(&affector[0]);
      __m128 a1 = _mm_load_ps(&affector[4]);
      __m128 a2 = _mm_load_ps(&affector[8]);
      __m128 a3 = _mm_load_ps(&affector[12]);
    
      // Load the four rows of the right-hand matrix (affected);
      // they are transposed into columns below
      __m128 b0 = _mm_load_ps(&affected[0]);
      __m128 b1 = _mm_load_ps(&affected[4]);
      __m128 b2 = _mm_load_ps(&affected[8]);
      __m128 b3 = _mm_load_ps(&affected[12]);
    
      // Example input rows:
      // b0 = [1, 2, 3, 4]
      // b1 = [5, 6, 7, 8]
      // b2 = [9, 10, 11, 12]
      // b3 = [13, 14, 15, 16]
      // need to arrive at-> b0 = [1, 5, 9, 13]   (column 0)
      // need to arrive at-> b1 = [2, 6, 10, 14]  (column 1)
      // need to arrive at-> b2 = [3, 7, 11, 15]  (column 2)
      // need to arrive at-> b3 = [4, 8, 12, 16]  (column 3)
    
      // tmp0 = [1, 5, 2, 6]
      __m128 tmp0 = _mm_unpacklo_ps(b0, b1);
      // tmp1 = [3, 7, 4, 8]
      __m128 tmp1 = _mm_unpackhi_ps(b0, b1);
      // tmp2 = [9, 13, 10, 14]
      __m128 tmp2 = _mm_unpacklo_ps(b2, b3);
      // tmp3 = [11, 15, 12, 16]
      __m128 tmp3 = _mm_unpackhi_ps(b2, b3);
    
      // b0 = [1, 5, 9, 13]  = low half of tmp0, then low half of tmp2
      b0 = _mm_movelh_ps(tmp0, tmp2);
      // b1 = [2, 6, 10, 14] = high half of tmp0, then high half of tmp2
      b1 = _mm_movehl_ps(tmp2, tmp0);
      // b2 = [3, 7, 11, 15] = low half of tmp1, then low half of tmp3
      b2 = _mm_movelh_ps(tmp1, tmp3);
      // b3 = [4, 8, 12, 16] = high half of tmp1, then high half of tmp3
      b3 = _mm_movehl_ps(tmp3, tmp1);
    
      // Dot product of row a0 = [x, y, z, d] with column b0 = [1, 5, 9, 13]:
      // the element-wise product is [x*1, y*5, z*9, d*13]
      __m128 mul = _mm_mul_ps(a0, b0);
      // First horizontal add sums adjacent pairs:
      // [x*1 + y*5, z*9 + d*13, x*1 + y*5, z*9 + d*13]
      mul = _mm_hadd_ps(mul, mul);
      // Second horizontal add leaves the full sum in every lane
      mul = _mm_hadd_ps(mul, mul);
      // Extract lane 0 into result[0]
      result[0] = _mm_cvtss_f32(mul);
    
      // Same multiply + double-hadd reduction for the remaining elements:
      // result[4*i + j] = dot(a_i, b_j)
    
      mul = _mm_mul_ps(a0, b1);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[1] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a0, b2);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[2] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a0, b3);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[3] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a1, b0);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[4] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a1, b1);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[5] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a1, b2);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[6] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a1, b3);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[7] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a2, b0);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[8] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a2, b1);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[9] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a2, b2);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[10] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a2, b3);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[11] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a3, b0);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[12] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a3, b1);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[13] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a3, b2);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[14] = _mm_cvtss_f32(mul);
    
      mul = _mm_mul_ps(a3, b3);
      mul = _mm_hadd_ps(mul, mul);
      mul = _mm_hadd_ps(mul, mul);
      result[15] = _mm_cvtss_f32(mul);
    }
    

    运行此功能1.000.000次,速度约为0.04秒 现在,我在想使用点积会加快速度,因为我不必:

    1. Multiply
    2. Do horizontal addition
    3. Do another horizontal addition
    

    而是仅仅:

    1. Single Dot product
    

    以下是SSE4.1的实现:

    /**
     * 4x4 matrix multiplication using SSE4.1 (_mm_dp_ps): result = affector * affected.
     *
     * All three pointers must reference 16-float arrays (row-major 4x4
     * matrices) that are 16-byte aligned — _mm_load_ps faults on unaligned
     * addresses. Computes result[4*i + j] = dot(row i of affector,
     * column j of affected).
     *
     * Loop is unrolled for performance.
     * @attention As opposed to non-SIMD multiplication we're using column-major
     *            access for the right-hand matrix: `affected` is transposed in
     *            registers below so each dpps is a row·column dot product.
     */
    inline void multiply(const float *__restrict affector, const float *__restrict affected, float *__restrict result)
    {
      // std::cout << "INSIDE SSE4" << std::endl;
      // Load the four rows of the left-hand matrix (affector)
      __m128 a0 = _mm_load_ps(&affector[0]);
      __m128 a1 = _mm_load_ps(&affector[4]);
      __m128 a2 = _mm_load_ps(&affector[8]);
      __m128 a3 = _mm_load_ps(&affector[12]);
    
      // Load the four rows of the right-hand matrix (affected);
      // they are transposed into columns below
      __m128 b0 = _mm_load_ps(&affected[0]);
      __m128 b1 = _mm_load_ps(&affected[4]);
      __m128 b2 = _mm_load_ps(&affected[8]);
      __m128 b3 = _mm_load_ps(&affected[12]);
    
      // Example input rows:
      // b0 = [1, 2, 3, 4]
      // b1 = [5, 6, 7, 8]
      // b2 = [9, 10, 11, 12]
      // b3 = [13, 14, 15, 16]
      // need to arrive at-> b0 = [1, 5, 9, 13]   (column 0)
      // need to arrive at-> b1 = [2, 6, 10, 14]  (column 1)
      // need to arrive at-> b2 = [3, 7, 11, 15]  (column 2)
      // need to arrive at-> b3 = [4, 8, 12, 16]  (column 3)
    
      // tmp0 = [1, 5, 2, 6]
      __m128 tmp0 = _mm_unpacklo_ps(b0, b1);
      // tmp1 = [3, 7, 4, 8]
      __m128 tmp1 = _mm_unpackhi_ps(b0, b1);
      // tmp2 = [9, 13, 10, 14]
      __m128 tmp2 = _mm_unpacklo_ps(b2, b3);
      // tmp3 = [11, 15, 12, 16]
      __m128 tmp3 = _mm_unpackhi_ps(b2, b3);
    
      // b0 = [1, 5, 9, 13]  = low half of tmp0, then low half of tmp2
      b0 = _mm_movelh_ps(tmp0, tmp2);
      // b1 = [2, 6, 10, 14] = high half of tmp0, then high half of tmp2
      b1 = _mm_movehl_ps(tmp2, tmp0);
      // b2 = [3, 7, 11, 15] = low half of tmp1, then low half of tmp3
      b2 = _mm_movelh_ps(tmp1, tmp3);
      // b3 = [4, 8, 12, 16] = high half of tmp1, then high half of tmp3
      b3 = _mm_movehl_ps(tmp3, tmp1);
    
      __m128 mul;
    
      // One dpps per output element: result[4*i + j] = dot(a_i, b_j).
      // Mask 0xF1: high nibble 0xF multiplies all four lanes,
      // low nibble 0x1 writes the sum only into lane 0.
      mul = _mm_dp_ps(a0, b0, 0xF1);  // Dot product of a0 and b0
      result[0] = _mm_cvtss_f32(mul); // Store lane 0
    
      mul = _mm_dp_ps(a0, b1, 0xF1); // Dot product of a0 and b1
      result[1] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a0, b2, 0xF1); // Dot product of a0 and b2
      result[2] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a0, b3, 0xF1); // Dot product of a0 and b3
      result[3] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a1, b0, 0xF1); // Dot product of a1 and b0
      result[4] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a1, b1, 0xF1); // Dot product of a1 and b1
      result[5] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a1, b2, 0xF1); // Dot product of a1 and b2
      result[6] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a1, b3, 0xF1); // Dot product of a1 and b3
      result[7] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a2, b0, 0xF1); // Dot product of a2 and b0
      result[8] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a2, b1, 0xF1); // Dot product of a2 and b1
      result[9] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a2, b2, 0xF1); // Dot product of a2 and b2
      result[10] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a2, b3, 0xF1); // Dot product of a2 and b3
      result[11] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a3, b0, 0xF1); // Dot product of a3 and b0
      result[12] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a3, b1, 0xF1); // Dot product of a3 and b1
      result[13] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a3, b2, 0xF1); // Dot product of a3 and b2
      result[14] = _mm_cvtss_f32(mul);
    
      mul = _mm_dp_ps(a3, b3, 0xF1); // Dot product of a3 and b3
      result[15] = _mm_cvtss_f32(mul);
    }
    

    结果是:~0.15秒!!!这甚至比我的不使用内部函数的实现(~0.11-0.12秒)和使用SSE2的实现(~0.10-0.09秒)还要慢。发生什么事??是因为点积在较低级别的实现方式,还是我做错了什么? 所有4个版本(非intrinsic、SSE2、SSE3、SSE4.1)都使用-O2进行了优化。

    1 回复  |  直到 4 月前
        1
  •  4
  •   Peter Cordes    4 月前

    haddps 也很慢,3 uops。除非你正在优化代码大小,而不是速度,否则不要对两个操作数使用相同的输入。(它有不同输入的转置和reduce用例。)

    但是 dpps 在最近的CPU上速度变慢了:Alder Lake P核上为6 uops,而SKL上为4 uops。( https://uops.info/ ). 最近的AMD也很糟糕,Zen 4有8个uops。你还没有说你在哪个CPU上测试过这个,但在某些CPU上,是的,dpps 版本甚至比 hadd 版本还要慢是有道理的。


    无论如何,即使有最好的混洗,单独对每个元素进行水平求和仍然绝对不是进行4x4 matmul的最有效方法。

    在减少时,您应该将多个源向量混洗在一起(或垂直添加),以产生整行或整列的输出。 Efficient 4x4 matrix multiplication (C vs assembly) 显示了一个非常干净的版本,它将一个矩阵的所有4行加载到4个向量中,然后分别广播另一个矩阵中的每个标量元素,一次4个。

    在AVX1的 vbroadcastss 出现之前,没有单条指令可以完成广播加载,所以在只有SSE的情况下,最好的选择可能是一次矢量加载,再配合4次 _mm_shuffle_ps(v,v, _MM_SHUFFLE(0,0,0,0))、_MM_SHUFFLE(1,1,1,1) 等。或者,编译器也许可以通过 _mm_set1_ps(A[4*i + 1]) 等为你生成同样的代码。

    SSE3或SSE4指令均不适用于高效的4x4矩阵。