代码之家  ›  专栏  ›  技术社区  ›  Sairus

朱莉娅:如何编写修改结构域的快速函数?

  •  2
  • Sairus  · 技术社区  · 7 年前

    我想写一些有效的方法来处理一些数据结构中的矩阵。我测试了外积的两个相同函数,一个在普通矩阵上操作,另一个在结构域上操作。第二个功能运行速度慢约25倍:

    mutable struct MyMatrix{T<:Real}
        mtx::Array{T}
        MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
    end
    
    function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
        # mtx = M.mtx - using local reference doesn't help
        len1 = length(x1)
        len2 = length(x2)
        size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
        for c=1:len2, r=1:len1
            M.mtx[r,c] = x1[r]*x2[c]
        end
        M
    end
    
    function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
        len1 = length(x1)
        len2 = length(x2)
        size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
        for c=1:len2, r=1:len1
            mtx[r,c] = x1[r]*x2[c]
        end
        mtx
    end
    
    N = 100;
    v1 = collect(Float64, 1:N)
    v2 = collect(Float64, N:-1:1)
    m = Array{Float64}(100,100)
    M = MyMatrix{Float64}(100)
    
    @time outerprod!(M,v1,v2);
    >>  0.001334 seconds (10.00 k allocations: 156.406 KiB)
    
    @time outerprod!(m,v1,v2);
    >>  0.000055 seconds (4 allocations: 160 bytes)
    

    最后,当我编写第三个版本时,引用了fast函数,它在结构上运行得同样快:

    function outerprod_!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
        outerprod!(M.mtx, x1, x2)
        M
    end
    
    @time outerprod_!(M,v1,v2);
    >>  0.000058 seconds (4 allocations: 160 bytes)
    

    第一个函数有什么问题?

    P、 美国在这个问题上苦苦挣扎了一段时间,在julia中进行了不同的优化,最终发现了这个问题。

    1 回复  |  直到 7 年前
        1
  •  1
  •   giordano    7 年前

    主要问题是 Array{<:Real} 不是具体类型:

    julia> Array{<:Real}
    Array{#s29,N} where N where #s29<:Real
    

    这种类型包括任何可能的 N 相反,你真的对矩阵感兴趣,所以它应该是 Array{T, 2} ,或者更容易键入和理解, Matrix{T} . 此外,请注意 MyMatrix 类型可以是不可变的:在不可变结构中,不能设置字段,但如果字段本身是可变的,则可以设置其内部字段。此外 for -环路可以通过使用 @inbounds :

    struct MyMatrix{T<:Real}
        mtx::Matrix{T}
        MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
    end
    
    function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
        # mtx = M.mtx - using local reference doesn't help
        len1 = length(x1)
        len2 = length(x2)
        size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
        @inbounds for c=1:len2, r=1:len1
            M.mtx[r,c] = x1[r]*x2[c]
        end
        M
    end
    
    function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
        len1 = length(x1)
        len2 = length(x2)
        size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
        @inbounds for c=1:len2, r=1:len1
            mtx[r,c] = x1[r]*x2[c]
        end
        mtx
    end
    
    N = 100;
    v1 = collect(Float64, 1:N)
    v2 = collect(Float64, N:-1:1)
    m = Matrix{Float64}(100,100)
    M = MyMatrix{Float64}(100)
    

    测试速度:

    julia> using BenchmarkTools
    
    julia> @btime outerprod!(m,v1,v2);
      2.746 μs (0 allocations: 0 bytes)
    
    julia> @btime outerprod!(M,v1,v2);
      2.746 μs (0 allocations: 0 bytes)