I think this is about halfway to the solution you want:
using StaticArrays
# compute y = (I₁ ⊗ a₂ᵗ ⊗ I₃) * θ
#
# I₁ is the N×N identity, a₂ a length-K vector and I₃ the M×M identity, so θ must
# have length K*M*N and the result is an SVector of length M*N. N and M are passed
# as `Val`s so that the output length is a compile-time constant.
#
# θ may be any 1-based AbstractVector (not only an SVector). The result element
# type is promote_type(eltype(a₂), eltype(θ)), so Float32/Complex/... inputs keep
# their type instead of being forced into a Float64 buffer.
# Throws DimensionMismatch if length(θ) != K*M*N.
function kronmul(::Val{N}, a₂::SVector{K}, ::Val{M}, θ::AbstractVector) where {N,K,M}
    # the index arithmetic below assumes θ starts at index 1
    Base.require_one_based_indexing(θ)
    length(θ) == K*M*N || throw(DimensionMismatch(
        "length(θ) = $(length(θ)), expected K*M*N = $(K*M*N)"))
    T = promote_type(eltype(a₂), eltype(θ))
    y = zeros(T, M*N)
    # crawls along θ with unit stride
    for i in 1:N, k in 1:K, j in 1:M
        y[M*(i-1) + j] += a₂[k] * θ[M*K*(i-1) + M*(k-1) + j]
    end
    return SVector{M*N,T}(y)
end
Here’s an example:
using LinearAlgebra
using BenchmarkTools
# problem sizes: I₁ is N×N, a₂ has length K, I₃ is M×M
N = 3
K = 4
M = 5
I₁, I₃ = I(N), I(M)
a₂ = SVector{K}(rand(K))
θ = SVector{K*M*N}(rand(K*M*N))
# lift the identity sizes into the type domain once, outside the benchmark
vN, vM = Val(N), Val(M)

# correctness: compare against the dense Kronecker product
expected = kron(I₁, a₂', I₃) * θ;
observed = kronmul(vN, a₂, vM, θ);
expected ≈ observed # true

# benchmark ($-interpolation keeps the globals from being treated as untyped)
@benchmark kronmul($vN, $a₂, $vM, $θ)
# BenchmarkTools.Trial:
# memory estimate: 208 bytes
# allocs estimate: 1
# --------------
# minimum time: 106.866 ns (0.00% GC)
# median time: 118.005 ns (0.00% GC)
# mean time: 126.083 ns (5.80% GC)
# maximum time: 2.194 μs (92.02% GC)
# --------------
# samples: 10000
# evals/sample: 947
If that looks good but you want to be able to write `A*θ` (with `A` representing the Kronecker product `I₁ ⊗ a₂ᵗ ⊗ I₃`), you could (ab)use LinearMaps.jl to implement `A_mul_B!` and/or `At_mul_B!` (since Julia 0.7 these are spelled `LinearAlgebra.mul!`). The following is a hacky, incomplete implementation:
import Base: *, size
import LinearMaps: LinearMap
# represents (I₁ ⊗ a₂ᵗ ⊗ I₃) which is M*N by K*M*N
# Lazy operator: only the vector a₂ is stored; I₁ (N×N identity) and I₃ (M×M
# identity) are implied by the type parameters N and M.
# Type parameters: N, M = identity sizes, K = length(a), T = eltype(a).
struct KronOperator{N,K,M,T} <: LinearMap{T}
    # the vector a₂; K and T are determined by its type
    a::SVector{K,T}
    # Construct with `KronOperator{N,M}(a)`: N and M are given explicitly, while
    # K and T are inferred from `a`. Because this inner constructor is defined,
    # the default `KronOperator{N,K,M,T}(a)` constructor is not generated.
    # NOTE(review): positionally, `KronOperator{N,M}` fills the struct's FIRST TWO
    # parameters (N and K). This method merely names the second one `M` and then
    # calls new{N,K,M,T}, so the result is KronOperator{N,length(a),M,eltype(a)}.
    # That result is not an instance of the type object `KronOperator{N,M}` that
    # was called (unless K happens to equal M). It works, but it is surprising; a
    # helper function taking Val(N), Val(M) might be clearer.
    function KronOperator{N,M}(a::SVector{K,T}) where {N,K,M,T}
        new{N,K,M,T}(a)
    end
end
# (rows, cols) of I₁ ⊗ a₂ᵗ ⊗ I₃: N*M rows, and K times as many columns
function Base.size(::KronOperator{N,K,M}) where {N,K,M}
    nrows = N * M
    ncols = nrows * K
    return (nrows, ncols)
end
# A * x: apply the Kronecker operator to `x` by delegating to `kronmul`.
# This method is not strictly needed; it exists only to avoid writing a
# 'mutating' version of kronmul for the LinearMaps interface.
# `x` must be a vector type that `kronmul` accepts.
# Throws DimensionMismatch (with the offending sizes) if length(x) != size(A, 2).
function Base.:(*)(A::KronOperator{N,K,M}, x::AbstractVector) where {N,K,M}
    ncols = size(A, 2)
    length(x) == ncols || throw(DimensionMismatch(
        "KronOperator has $ncols columns, but the vector has length $(length(x))"))
    y = kronmul(Val(N), A.a, Val(M), x)
    return y
end
It seems to have similar performance to calling `kronmul` directly, as before:
# problem sizes: I₁ is N×N, a₂ has length K, I₃ is M×M
N = 3
K = 4
M = 5
I₁, I₃ = I(N), I(M)
a₂ = SVector{K}(rand(K))
θ = SVector{K*M*N}(rand(K*M*N))
# lazy operator for I₁ ⊗ a₂ᵗ ⊗ I₃
A = KronOperator{N,M}(a₂)

# correctness: compare against the dense Kronecker product
expected = kron(I₁, a₂', I₃) * θ;
observed = A * θ;
expected ≈ observed # true

# benchmark ($-interpolation keeps the globals from being treated as untyped)
@benchmark *($A, $θ)
# BenchmarkTools.Trial:
# memory estimate: 208 bytes
# allocs estimate: 1
# --------------
# minimum time: 105.308 ns (0.00% GC)
# median time: 120.011 ns (0.00% GC)
# mean time: 130.112 ns (5.67% GC)
# maximum time: 2.916 μs (94.41% GC)
# --------------
# samples: 10000
# evals/sample: 946
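EDIT: I missed @mcabbott’s solution, which appears to implement the same algorithm. However, the benchmark below is much slower (about 7 μs and 3 allocations, versus about 110 ns). This is likely an artifact of the benchmark call rather than of the algorithm. In `$Val(N)` the `$` interpolates only `Val`, so `Val(N)` is built at run time from the non-constant global `N`, and `f` is then reached through dynamic dispatch. Interpolating the whole value, `$(Val(N))` (or a precomputed `$vN` as above), should give a fair comparison.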
# `f` is @mcabbott's implementation (not shown here).
# Interpolate the whole `Val(N)` value. Written as `$Val(N)`, the `$` applies only
# to `Val`, so `Val(N)` would be evaluated during the benchmark with the
# non-constant global `N`, adding dynamic dispatch to the measured time.
# (The timings recorded below were measured with the `$Val(N)` form.)
@benchmark f($(Val(N)), $a₂, $(Val(M)), $θ)
# BenchmarkTools.Trial:
# memory estimate: 672 bytes
# allocs estimate: 3
# --------------
# minimum time: 6.812 μs (0.00% GC)
# median time: 7.012 μs (0.00% GC)
# mean time: 7.099 μs (0.00% GC)
# maximum time: 23.861 μs (0.00% GC)
# --------------
# samples: 10000
# evals/sample: 5