For reference, here is a numerical example. (I forgot to include a burn-in run, so the timings of the first estimate also contain one-time warm-up costs such as compilation — take them with a grain of salt.)
using Random, Distributions, LinearAlgebra, IterativeSolvers, LinearMaps, BenchmarkTools
# Seed the default RNG so the data drawn with rand/randn in `test` is reproducible.
Random.seed!(42)
"""
    L(b, X, Z)

Apply the interaction operator to the coefficient matrix `b` and return a new
vector `Y` with

    Y[i] = Σₖ Σₗ b[ℓ, k] * X[i, k] * Z[i, ℓ]

# Arguments
- `b`: coefficient matrix of size `(size(Z, 2), size(X, 2))`.
- `X`: matrix with one row per entry of the result.
- `Z`: matrix with the same number of rows as `X`.

# Returns
A `Vector` of length `size(X, 1)` whose element type is the promotion of the
element types of `b`, `X` and `Z`.

# Throws
- `DimensionMismatch` if the row counts of `X` and `Z` differ, or if `b` does not
  have size `(size(Z, 2), size(X, 2))`.
- `ArgumentError` (from `Base.require_one_based_indexing`) for offset arrays.
"""
function L(b::AbstractMatrix, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(b, X, Z)
    N = size(X, 1)
    # The loops below run under @inbounds, so every shape must be validated first.
    size(Z, 1) == N || throw(DimensionMismatch(
        "X has $N rows but Z has $(size(Z, 1)) rows"))
    size(b) == (size(Z, 2), size(X, 2)) || throw(DimensionMismatch(
        "b has size $(size(b)), expected $((size(Z, 2), size(X, 2)))"))
    Y = zeros(promote_type(eltype(b), eltype(X), eltype(Z)), N)
    # Innermost loop runs over rows so X and Z are traversed down their columns.
    @inbounds for k in axes(X, 2), ℓ in axes(Z, 2)
        @simd for i in eachindex(Y)
            Y[i] += b[ℓ, k] * X[i, k] * Z[i, ℓ]
        end
    end
    return Y
end
"""
    L!(Y, b, X, Z)

In-place version of the interaction operator: overwrite `Y` with

    Y[i] = Σₖ Σₗ b[ℓ, k] * X[i, k] * Z[i, ℓ]

and return `Y`. Any previous contents of `Y` are discarded.

# Arguments
- `Y`: output vector of length `size(X, 1)`; it is zeroed and then accumulated into.
- `b`: coefficient matrix of size `(size(Z, 2), size(X, 2))`.
- `X`, `Z`: matrices with the same number of rows as `Y` has entries.

# Throws
- `DimensionMismatch` if `length(Y)`, `size(X, 1)` and `size(Z, 1)` are not all
  equal, or if `b` does not have size `(size(Z, 2), size(X, 2))`.
- `ArgumentError` (from `Base.require_one_based_indexing`) for offset arrays.
"""
function L!(Y::AbstractVector, b::AbstractMatrix, X::AbstractMatrix, Z::AbstractMatrix)
    # Y supplies the row indices used for X and Z, so it must be one-based as well.
    Base.require_one_based_indexing(Y, b, X, Z)
    N = length(Y)
    # The loops below run under @inbounds, so every shape must be validated first.
    (size(X, 1) == N && size(Z, 1) == N) || throw(DimensionMismatch(
        "Y has length $N but X has $(size(X, 1)) rows and Z has $(size(Z, 1)) rows"))
    size(b) == (size(Z, 2), size(X, 2)) || throw(DimensionMismatch(
        "b has size $(size(b)), expected $((size(Z, 2), size(X, 2)))"))
    fill!(Y, zero(eltype(Y)))
    @inbounds for k in axes(X, 2), ℓ in axes(Z, 2)
        @simd for i in eachindex(Y)
            Y[i] += b[ℓ, k] * X[i, k] * Z[i, ℓ]
        end
    end
    return Y
end
"""
    Lᵀ(Y, X, Z)

Apply the transpose of the interaction operator to `Y` and return a new matrix
`b` of size `(size(Z, 2), size(X, 2))` with

    b[ℓ, k] = Σᵢ Y[i] * X[i, k] * Z[i, ℓ]

# Arguments
- `Y`: vector of length `size(X, 1)`.
- `X`, `Z`: matrices with the same number of rows as `Y` has entries.

# Returns
A `Matrix` whose element type is the promotion of the element types of `Y`, `X`
and `Z`.

# Throws
- `DimensionMismatch` if `length(Y)`, `size(X, 1)` and `size(Z, 1)` are not all equal.
- `ArgumentError` (from `Base.require_one_based_indexing`) for offset arrays.
"""
function Lᵀ(Y::AbstractVector, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(Y, X, Z)
    N = length(Y)
    # The loops below run under @inbounds, so the row counts must be validated first.
    (size(X, 1) == N && size(Z, 1) == N) || throw(DimensionMismatch(
        "Y has length $N but X has $(size(X, 1)) rows and Z has $(size(Z, 1)) rows"))
    T = promote_type(eltype(Y), eltype(X), eltype(Z))
    n, m = size(X, 2), size(Z, 2)
    b = Matrix{T}(undef, m, n)
    @inbounds for k in axes(X, 2), ℓ in axes(Z, 2)
        # Accumulate in the promoted type so the accumulator never changes type
        # inside the hot loop (e.g. integer Y combined with floating-point X).
        acc = zero(T)
        @simd for i in eachindex(Y)
            acc += Y[i] * X[i, k] * Z[i, ℓ]
        end
        b[ℓ, k] = acc
    end
    return b
end
"""
    Lᵀ!(b, Y, X, Z)

In-place version of the transposed interaction operator: overwrite `b` with

    b[ℓ, k] = Σᵢ Y[i] * X[i, k] * Z[i, ℓ]

and return `b`. Every entry of `b` is assigned, so its previous contents do not
matter.

# Arguments
- `b`: output matrix of size `(size(Z, 2), size(X, 2))`.
- `Y`: vector of length `size(X, 1)`.
- `X`, `Z`: matrices with the same number of rows as `Y` has entries.

# Throws
- `DimensionMismatch` if `length(Y)`, `size(X, 1)` and `size(Z, 1)` are not all
  equal, or if `b` does not have size `(size(Z, 2), size(X, 2))`.
- `ArgumentError` (from `Base.require_one_based_indexing`) for offset arrays.
"""
function Lᵀ!(b::AbstractMatrix, Y::AbstractVector, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(b, Y, X, Z)
    N = length(Y)
    # b is written under @inbounds below: a wrong shape would be an out-of-bounds
    # write, so validate everything up front.
    (size(X, 1) == N && size(Z, 1) == N) || throw(DimensionMismatch(
        "Y has length $N but X has $(size(X, 1)) rows and Z has $(size(Z, 1)) rows"))
    size(b) == (size(Z, 2), size(X, 2)) || throw(DimensionMismatch(
        "b has size $(size(b)), expected $((size(Z, 2), size(X, 2)))"))
    # Including eltype(b) keeps the accumulation precision of a `zero(eltype(b))`
    # start value while giving the accumulator a single, fixed type.
    T = promote_type(eltype(b), eltype(Y), eltype(X), eltype(Z))
    @inbounds for k in axes(X, 2), ℓ in axes(Z, 2)
        acc = zero(T)
        @simd for i in eachindex(Y)
            acc += Y[i] * X[i, k] * Z[i, ℓ]
        end
        b[ℓ, k] = acc
    end
    return b
end
"""
    test(N, d, k)

Benchmark driver. Draws random `X` (`N × d`) and `Z` (`N × k`), builds the dense
interaction design matrix, simulates a response from it, and then, for ten
nested column subsets of `X`, fits the coefficients four ways: a dense
least-squares solve with the backslash operator, `lsqr` and `lsmr` on an
out-of-place matrix-free operator, and `lsmr` on an in-place operator.

Prints the timing of each fit as it goes and, at the end, the four vectors of
mean squared residuals followed by the four vectors of runtimes. Returns
`nothing`.
"""
function test(N, d, k)
    X = rand(N, d)
    Z = rand(N, k)

    # Dense design matrix: column (di - 1) * k + ℓ holds X[:, di] .* Z[:, ℓ].
    XZ = Matrix{Float64}(undef, N, d * k)
    @time for di in 1:d
        cols = ((di - 1) * k + 1):(di * k)
        XZ[:, cols] .= view(X, :, di) .* Z
    end

    β = rand(d * k)
    Y = XZ * β + randn(N)

    nsteps = 10
    dgrid = floor.(Int, range(1, d, nsteps))

    mse_dense = Vector{Float64}(undef, nsteps)
    mse_lsqr = Vector{Float64}(undef, nsteps)
    mse_lsmr = Vector{Float64}(undef, nsteps)
    mse_lsmr_ip = Vector{Float64}(undef, nsteps)
    t_dense = Vector{Float64}(undef, nsteps)
    t_lsqr = Vector{Float64}(undef, nsteps)
    t_lsmr = Vector{Float64}(undef, nsteps)
    t_lsmr_ip = Vector{Float64}(undef, nsteps)

    # Mean squared residual of a fitted response against Y.
    sqresid(fitted) = mean((fitted - Y) .^ 2)

    for (i, dsub) in enumerate(dgrid)
        # Restrict to the first dsub columns of X and rebuild the dense design.
        X_sub = X[:, 1:dsub]
        XZ_sub = Matrix{Float64}(undef, N, dsub * k)
        for di in 1:dsub
            cols = ((di - 1) * k + 1):(di * k)
            XZ_sub[:, cols] .= view(X_sub, :, di) .* Z
        end

        println("Estimating $i using inverse")
        dense = @timed XZ_sub \ Y
        println("Took $(dense.time) seconds")
        mse_dense[i] = sqresid(XZ_sub * dense.value)
        t_dense[i] = dense.time

        println("Estimating $i using lsqr")
        # Matrix-free operators: the coefficient vector is viewed as a k × dsub matrix.
        op = FunctionMap{Float64,false}(
            v -> L(reshape(v, k, dsub), X_sub, Z),
            r -> vec(Lᵀ(r, X_sub, Z)),
            N, k * dsub)
        op_inplace = FunctionMap{Float64,true}(
            (out, v) -> L!(out, reshape(v, k, dsub), X_sub, Z),
            (out, r) -> vec(Lᵀ!(reshape(out, k, dsub), r, X_sub, Z)),
            N, k * dsub)
        fit_lsqr = @timed reshape(lsqr(op, Y), k, dsub)
        println("Took $(fit_lsqr.time) seconds")
        # The timed value is already k × dsub, so it is passed to L directly.
        mse_lsqr[i] = sqresid(L(fit_lsqr.value, X_sub, Z))
        t_lsqr[i] = fit_lsqr.time

        println("Estimating $i using lsmr")
        fit_lsmr = @timed reshape(lsmr(op, Y), k, dsub)
        println("Took $(fit_lsmr.time) seconds")
        mse_lsmr[i] = sqresid(L(fit_lsmr.value, X_sub, Z))
        t_lsmr[i] = fit_lsmr.time

        println("Estimating $i using lsmr inplace")
        fit_ip = @timed reshape(lsmr(op_inplace, Y), k, dsub)
        println("Took $(fit_ip.time) seconds")
        mse_lsmr_ip[i] = sqresid(L(fit_ip.value, X_sub, Z))
        t_lsmr_ip[i] = fit_ip.time
    end

    # Summary: the four residual vectors first, then the four runtime vectors.
    for summary in (mse_dense, mse_lsqr, mse_lsmr, mse_lsmr_ip,
                    t_dense, t_lsqr, t_lsmr, t_lsmr_ip)
        println(summary)
    end
    return nothing
end
# Run the benchmark with N = 100_000 rows, d = 50 columns in X and k = 100 columns in Z.
test(100_000, 50, 100)
This gives the following output:
0.379931 seconds
Estimating 1 using inverse
Took 0.228198375 seconds
Estimating 1 using lsqr
Took 0.20376975 seconds
Estimating 1 using lsmr
Took 0.114865875 seconds
Estimating 1 using lsmr inplace
Took 0.12395125 seconds
Estimating 2 using inverse
Took 3.092642 seconds
Estimating 2 using lsqr
Took 0.784247291 seconds
Estimating 2 using lsmr
Took 0.615687167 seconds
Estimating 2 using lsmr inplace
Took 0.615284 seconds
Estimating 3 using inverse
Took 9.574711583 seconds
Estimating 3 using lsqr
Took 2.021692375 seconds
Estimating 3 using lsmr
Took 1.425100083 seconds
Estimating 3 using lsmr inplace
Took 1.395683833 seconds
Estimating 4 using inverse
Took 21.734518958 seconds
Estimating 4 using lsqr
Took 3.4526595 seconds
Estimating 4 using lsmr
Took 2.303357708 seconds
Estimating 4 using lsmr inplace
Took 2.277328375 seconds
Estimating 5 using inverse
Took 37.174790166 seconds
Estimating 5 using lsqr
Took 5.026946708 seconds
Estimating 5 using lsmr
Took 3.495458833 seconds
Estimating 5 using lsmr inplace
Took 3.433379792 seconds
Estimating 6 using inverse
Took 57.444847375 seconds
Estimating 6 using lsqr
Took 6.52967775 seconds
Estimating 6 using lsmr
Took 4.626406417 seconds
Estimating 6 using lsmr inplace
Took 4.491984125 seconds
Estimating 7 using inverse
Took 79.451948959 seconds
Estimating 7 using lsqr
Took 8.227401208 seconds
Estimating 7 using lsmr
Took 5.232909541 seconds
Estimating 7 using lsmr inplace
Took 5.240785166 seconds
Estimating 8 using inverse
Took 110.555767958 seconds
Estimating 8 using lsqr
Took 9.65445625 seconds
Estimating 8 using lsmr
Took 6.07731025 seconds
Estimating 8 using lsmr inplace
Took 6.072688333 seconds
Estimating 9 using inverse
Took 136.001674292 seconds
Estimating 9 using lsqr
Took 11.162575667 seconds
Estimating 9 using lsmr
Took 6.841667083 seconds
Estimating 9 using lsmr inplace
Took 6.843795709 seconds
Estimating 10 using inverse
Took 177.454427917 seconds
Estimating 10 using lsqr
Took 15.239018417 seconds
Estimating 10 using lsmr
Took 10.854387917 seconds
Estimating 10 using lsmr inplace
Took 11.097065416 seconds
[97486.16854910454, 18328.43239810469, 9018.555817126977, 4893.845608371163, 3189.6915435352157, 1989.801489581642, 1294.25192687956, 706.1769890062392, 347.7741537705133, 0.9500232830021116]
[97486.16854910481, 18328.432398146768, 9018.555817133984, 4893.845608435037, 3189.691543539357, 1989.8014896004731, 1294.2519268813387, 706.1769890090135, 347.7741537748462, 0.9500232830059561]
[97486.16854936874, 18328.432408960995, 9018.55590730859, 4893.845862939049, 3189.691580688102, 1989.8015190086182, 1294.2519695989643, 706.1770434840151, 347.77421115991575, 0.9500233151409269]
[97486.16854936874, 18328.432408960995, 9018.55590730859, 4893.845862939049, 3189.691580688102, 1989.8015190086182, 1294.2519695989643, 706.1770434840151, 347.77421115991575, 0.9500233151409269]
[0.228198375, 3.092642, 9.574711583, 21.734518958, 37.174790166, 57.444847375, 79.451948959, 110.555767958, 136.001674292, 177.454427917]
[0.20376975, 0.784247291, 2.021692375, 3.4526595, 5.026946708, 6.52967775, 8.227401208, 9.65445625, 11.162575667, 15.239018417]
[0.114865875, 0.615687167, 1.425100083, 2.303357708, 3.495458833, 4.626406417, 5.232909541, 6.07731025, 6.841667083, 10.854387917]
[0.12395125, 0.615284, 1.395683833, 2.277328375, 3.433379792, 4.491984125, 5.240785166, 6.072688333, 6.843795709, 11.097065416]