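Let’s define the linear operator $L[b] = \left(\sum_{k,\ell} b_{\ell,k} X_{i,k} Z_{i,\ell}\right)_{i=1,\ldots,N} = \operatorname{diag}(Z b X^T)$, where $b$ is the $m \times n$ coefficient matrix, $X$ is $N \times n$, and $Z$ is $N \times m$. You could implement this in Julia with, e.g.: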
```julia
function L(b::AbstractMatrix, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(b, X, Z)
    m, n = size(b)
    N = size(X, 1)
    (n == size(X, 2) && m == size(Z, 2) && N == size(Z, 1)) || throw(DimensionMismatch())
    # element type of the result: whatever a product of the three element types gives
    Y = Vector{typeof(zero(eltype(b)) * zero(eltype(X)) * zero(eltype(Z)))}(undef, N)
    @inbounds for i = 1:N
        # Yᵢ = Σₖ Σℓ b[ℓ,k] X[i,k] Z[i,ℓ], accumulated without temporary arrays
        Yᵢ = zero(eltype(Y))
        for k = 1:n
            @simd for ℓ = 1:m
                Yᵢ += b[ℓ,k] * X[i,k] * Z[i,ℓ]
            end
        end
        Y[i] = Yᵢ
    end
    return Y
end
```
(You could also use something like Tullio.jl, but unfortunately that doesn’t work with the latest Julia. There are various other ways to implement this more compactly, but I wanted to avoid any unnecessary allocations since your matrices might be large.)
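For comparison, here is one of those more compact (but allocating) formulations, a sketch based on the identity $L[b]_i = \sum_k (Z b)_{i,k} X_{i,k}$. The name `L_compact` is just an illustrative choice, and the check against `diag(Z * b * transpose(X))` is only meant for small test matrices, since it forms an $N \times N$ product.

```julia
using LinearAlgebra

# Hypothetical compact version of L: row-wise sums of (Z*b) .* X.
# Allocates the N×n temporary Z*b, but needs no explicit loops.
L_compact(b::AbstractMatrix, X::AbstractMatrix, Z::AbstractMatrix) =
    vec(sum((Z * b) .* X; dims = 2))

# Small sanity check against the definition diag(Z b Xᵀ)
m, n, N = 4, 3, 10
b, X, Z = rand(m, n), rand(N, n), rand(N, m)
Y_compact = L_compact(b, X, Z)
Y_dense = diag(Z * b * transpose(X))   # forms an N×N matrix: small N only
println(Y_compact ≈ Y_dense)
```

Note that `transpose` (not `'`) matches the definition of $L$, which does not conjugate $X$.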
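For least squares, you will also need the transposed (adjoint) operator
$$L^T[Y] = \left(\sum_{i=1}^{N} Y_i\, \overline{X_{i,k} Z_{i,\ell}}\right)_{\ell,k} = Z^H \operatorname{diag}(Y)\, \overline{X},$$
which could be implemented, e.g., with: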
```julia
function Lᵀ(Y::AbstractVector, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(Y, X, Z)
    n, m = size(X, 2), size(Z, 2)
    N = length(Y)
    N == size(X, 1) == size(Z, 1) || throw(DimensionMismatch())
    b = Matrix{typeof(zero(eltype(Y)) * zero(eltype(X)) * zero(eltype(Z)))}(undef, m, n)
    @inbounds for k = 1:n, ℓ = 1:m
        # b[ℓ,k] = Σᵢ Y[i] conj(X[i,k] Z[i,ℓ]); starting from zero(eltype(b)) keeps the
        # accumulator type-stable even if Y is real while X or Z is complex
        b_ℓk = zero(eltype(b))
        @simd for i = 1:N
            b_ℓk += Y[i] * conj(X[i,k] * Z[i,ℓ])
        end
        b[ℓ,k] = b_ℓk
    end
    return b
end
```
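A quick way to check that a hand-written transpose really is the adjoint of $L$ is the dot-product test $\langle Y, L[b]\rangle = \langle L^T[Y], b\rangle$, where $\langle\cdot,\cdot\rangle$ conjugates its first argument. The sketch below applies it to compact matrix forms of both operators (the names `Lc` and `Lcᵀ` are hypothetical; in practice you would pass the loop-based `L` and `Lᵀ` instead), using complex random data so that the conjugation is exercised too.

```julia
using LinearAlgebra

# Hypothetical compact forms of the two operators:
#   L[b]  = diag(Z b Xᵀ)          (no conjugation, as in the definition of L)
#   Lᵀ[Y] = Zᴴ diag(Y) conj(X)    (adjoint with respect to the complex dot product)
Lc(b, X, Z) = vec(sum((Z * b) .* X; dims = 2))
Lcᵀ(Y, X, Z) = Z' * (Y .* conj.(X))

m, n, N = 4, 3, 50
b = randn(ComplexF64, m, n)
X, Z = randn(ComplexF64, N, n), randn(ComplexF64, N, m)
Y = randn(ComplexF64, N)

# dot(u, v) conjugates its first argument, so both sides should agree to roundoff
lhs = dot(Y, Lc(b, X, Z))
rhs = dot(Lcᵀ(Y, X, Z), b)
println(abs(lhs - rhs) / abs(lhs))
```

If the printed relative difference is not at roundoff level, the transpose (or its conjugation) is wrong, and an iterative least-squares solver will silently converge to the wrong answer or fail to converge.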
Solving the least-squares problem is equivalent (modulo roundoff errors) to solving the normal equations $L^T[L[b]] = L^T[Y]$. So you just need to pass these linear operators to an iterative least-squares algorithm (defined implicitly from the mapping functions with, e.g., LinearMaps.jl), using `reshape` and `vec` to convert matrices to/from vectors.
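For small problems, the equivalence can be checked with an explicit matrix. Because `vec(b)` is column-major, entry $(\ell,k)$ of $b$ sits at index $\ell + (k-1)m$, so row $i$ of the $N \times mn$ matrix representing $L$ is `kron(X[i,:], Z[i,:])`. The sketch below builds that matrix (`A` is an illustrative name) and compares the normal-equations solution with Julia's QR-based backslash. It needs $O(Nmn)$ storage, so it is only a sanity check, not a replacement for the matrix-free approach.

```julia
using LinearAlgebra, Random

Random.seed!(1234)
m, n, N = 4, 3, 100
X, Z, Y = rand(N, n), rand(N, m), rand(N)

# Dense N × (m*n) matrix of L acting on vec(b): row i is kron(X[i,:], Z[i,:]),
# since kron(x, z) holds x[k]*z[ℓ] at index ℓ + (k-1)*m, matching vec(b)
A = permutedims(reduce(hcat, [kron(X[i, :], Z[i, :]) for i = 1:N]))

# normal equations Aᵀ A vec(b) = Aᵀ Y, versus the QR-based least-squares solve A \ Y
b_normal = reshape((A' * A) \ (A' * Y), m, n)
b_qr = reshape(A \ Y, m, n)

# A really represents L: A * vec(b) == diag(Z b Xᵀ)
println(A * vec(b_qr) ≈ diag(Z * b_qr * transpose(X)))
println(norm(b_normal - b_qr) / norm(b_qr))
```

The two solutions differ only by roundoff here; forming `A' * A` squares the condition number, which is why iterative solvers such as LSQR work with $L$ and $L^T$ separately rather than with their product.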
For example, using IterativeSolvers.jl and some random data:
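```julia
using IterativeSolvers, LinearMaps

m, n, N = 4, 3, 100
X, Z, Y = rand(N, n), rand(N, m), rand(N)

# wrap L and Lᵀ as a matrix-free N × (m*n) operator acting on vec(b)
op = FunctionMap{Float64,false}(bvec -> L(reshape(bvec, m, n), X, Z),
                                Y -> vec(Lᵀ(Y, X, Z)), N, m*n)

# minimize ‖L[b] - Y‖ iteratively, then reshape the solution back to m × n
b = reshape(lsqr(op, Y), m, n)
```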
(Caveat: the code above runs, but I haven’t tested it carefully.) Of course, you should put the above code into a function if you want to run it on a large-scale problem, as otherwise the global variables will slow things down.
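A minimal sketch of that advice, assuming LinearMaps.jl and IterativeSolvers.jl are installed: the hypothetical function `fit_b` takes the data as arguments, so everything inside it is local rather than an untyped global. To keep the snippet self-contained it uses compact (allocating) stand-ins for the operators; for large problems you would substitute the loop-based `L` and `Lᵀ`. It builds the operator with the `LinearMap{T}(f, fc, M, N)` constructor instead of the `FunctionMap` call above, and compares the result with a dense QR solve on a small random problem.

```julia
using LinearAlgebra, Random
using IterativeSolvers, LinearMaps

# Hypothetical wrapper: all variables are local, so nothing relies on untyped globals.
function fit_b(X::AbstractMatrix, Z::AbstractMatrix, Y::AbstractVector)
    N, n = size(X)
    m = size(Z, 2)
    # compact stand-ins for L and Lᵀ (swap in the loop versions for large problems)
    Lop(bvec) = vec(sum((Z * reshape(bvec, m, n)) .* X; dims = 2))
    Lopᵀ(y) = vec(Z' * (y .* conj.(X)))
    T = promote_type(eltype(X), eltype(Z), eltype(Y))
    op = LinearMap{T}(Lop, Lopᵀ, N, m * n)   # matrix-free N × (m*n) operator
    return reshape(lsqr(op, Y), m, n)
end

Random.seed!(42)
m, n, N = 4, 3, 100
X, Z, Y = rand(N, n), rand(N, m), rand(N)
b = fit_b(X, Z, Y)

# dense reference solution: row i of the explicit matrix is kron(X[i,:], Z[i,:])
A = permutedims(reduce(hcat, [kron(X[i, :], Z[i, :]) for i = 1:N]))
b_ref = reshape(A \ Y, m, n)
println(norm(b - b_ref) / norm(b_ref))
```

LSQR stops at a tolerance rather than solving exactly, so expect agreement with the dense solution to several digits rather than to full machine precision.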