Best way to solve a linear regression problem on product features

Let’s define the linear operator L[b] = \left(\sum_{k,\ell} b_{\ell,k} X_{i,k} Z_{i,\ell}\right)_{i=1,\ldots,N} = \operatorname{diag}(Z b X^T), where X is N \times n, Z is N \times m, and b is the m \times n coefficient matrix. You could implement this in Julia with e.g.

function L(b::AbstractMatrix, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(b, X, Z)
    m, n = size(b)
    N = size(X, 1)
    (n == size(X,2) && m == size(Z,2) && N == size(Z,1)) || throw(DimensionMismatch())
    # element type of the product, so mixed real/complex inputs promote correctly:
    Y = Vector{typeof(zero(eltype(b)) * zero(eltype(X)) * zero(eltype(Z)))}(undef, N)
    @inbounds for i = 1:N
        Yᵢ = zero(eltype(Y))
        for k = 1:n; @simd for ℓ = 1:m   # Yᵢ = ∑ₖ ∑ₗ b[ℓ,k] X[i,k] Z[i,ℓ]
            Yᵢ += b[ℓ,k] * X[i,k] * Z[i,ℓ]
        end; end
        Y[i] = Yᵢ
    end
    return Y
end

(You could also use something like Tullio.jl, but unfortunately that doesn’t work with the latest Julia. There are various other ways to implement this more compactly, but I wanted to avoid any unnecessary allocations since your matrices might be large.)
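
As a quick sanity check on small random inputs, you might verify the loop against the \operatorname{diag}(Z b X^T) identity above (diag is from the LinearAlgebra standard library):

using LinearAlgebra

m, n, N = 4, 3, 10
b, X, Z = rand(m, n), rand(N, n), rand(N, m)
L(b, X, Z) ≈ diag(Z * b * X')   # should be true (X' == transpose for real X)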

For least-squares, you will also need the transposed operator (more precisely, the adjoint; the conjugation below only matters for complex data):

L^T[Y] = \left(\sum_{i} Y_i \overline{X_{i,k} Z_{i,\ell}}\right)_{\ell = 1\ldots m, k = 1\ldots n}

which could be implemented e.g. with:

function Lᵀ(Y::AbstractVector, X::AbstractMatrix, Z::AbstractMatrix)
    Base.require_one_based_indexing(Y, X, Z)
    n, m = size(X,2), size(Z,2)
    N = length(Y)
    N == size(X,1) == size(Z,1) || throw(DimensionMismatch())
    b = Matrix{typeof(zero(eltype(Y)) * zero(eltype(X)) * zero(eltype(Z)))}(undef, m, n)
    @inbounds for k = 1:n, ℓ = 1:m
        b_ℓk = zero(eltype(b))   # accumulate in the promoted type for type stability
        @simd for i = 1:N
            b_ℓk += Y[i] * conj(X[i,k] * Z[i,ℓ])
        end
        b[ℓ,k] = b_ℓk
    end
    return b
end
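
Similarly, on small random data you can check that Lᵀ really is the adjoint of L, i.e. that ⟨Y, L[b]⟩ = ⟨Lᵀ[Y], b⟩ (recall that Julia’s dot conjugates its first argument):

using LinearAlgebra

m, n, N = 4, 3, 10
b, X, Z, Y = rand(m, n), rand(N, n), rand(N, m), rand(N)
dot(Y, L(b, X, Z)) ≈ dot(vec(Lᵀ(Y, X, Z)), vec(b))   # should be true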

Solving the least-squares problem is equivalent (modulo roundoff errors) to solving the normal equations L^T[L[b]] = L^T[Y]. So, you just need to pass these linear operators to an iterative least-squares algorithm, defining the operator implicitly from the two mapping functions (e.g. with LinearMaps.jl) and using reshape and vec to convert matrices to/from vectors.

For example, using IterativeSolvers.jl and some random data:

using IterativeSolvers, LinearMaps

m, n, N = 4, 3, 100
X, Z, Y = rand(N, n), rand(N, m), rand(N)

# N × (m*n) map acting on vec(b), together with its adjoint:
op = FunctionMap{Float64,false}(bvec -> L(reshape(bvec,m,n), X, Z), y -> vec(Lᵀ(y, X, Z)), N, m*n)
b = reshape(lsqr(op, Y), m, n)   # iterative least-squares solve

(Caveat: the code above runs, but I haven’t tested it carefully.) Of course, you should put the above code into a function if you want to run it on a large-scale problem, as otherwise the global variables will slow things down.
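
If you want to double-check the whole pipeline on a small problem (only as a sanity check: this materializes the dense N × mn matrix, which defeats the purpose for large problems), you could compare against an explicit dense least-squares solve:

# column (k-1)*m + ℓ of A matches vec-index (ℓ,k) of b, so A * vec(b) == L(b, X, Z)
A = reduce(hcat, [X[:,k] .* Z[:,ℓ] for k in 1:n for ℓ in 1:m])
b_dense = reshape(A \ Y, m, n)
b_dense ≈ b   # should agree up to the iterative solver's tolerance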
