主要问题是
Array{<:Real}
不是具体类型:
julia> Array{<:Real}
Array{#s29,N} where N where #s29<:Real
这种类型包括任何可能的
N
相反,你真的对矩阵感兴趣,所以它应该是
Array{T, 2}
,或者更容易键入和理解,
Matrix{T}
. 此外,请注意
MyMatrix
类型可以是不可变的:在不可变结构中,不能设置字段,但如果字段本身是可变的,则可以设置其内部字段。此外
for
-环路可以通过使用
@inbounds
:
struct MyMatrix{T<:Real}
mtx::Matrix{T}
MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
end
function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
# mtx = M.mtx - using local reference doesn't help
len1 = length(x1)
len2 = length(x2)
size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
@inbounds for c=1:len2, r=1:len1
M.mtx[r,c] = x1[r]*x2[c]
end
M
end
function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
len1 = length(x1)
len2 = length(x2)
size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
@inbounds for c=1:len2, r=1:len1
mtx[r,c] = x1[r]*x2[c]
end
mtx
end
N = 100;
v1 = collect(Float64, 1:N)
v2 = collect(Float64, N:-1:1)
m = Matrix{Float64}(100,100)
M = MyMatrix{Float64}(100)
测试速度:
julia> using BenchmarkTools
julia> @btime outerprod!(m,v1,v2);
2.746 μs (0 allocations: 0 bytes)
julia> @btime outerprod!(M,v1,v2);
2.746 μs (0 allocations: 0 bytes)