# Kronecker times vector with special structure

I would like to evaluate

$$(I_1 \otimes a_2^\top \otimes I_3) \cdot \theta$$

where

1. $I_1$ and $I_3$ are identity matrices of size $n \times n$ and $m \times m$, respectively,
2. $a_2$ is a vector of length $k$,
3. $\theta$ is a vector of length $mnk$.

Is there a clever way of doing this without forming the Kronecker product? I am particularly interested in SVectors; ideally I am looking for something that would make this method fast:

function f(::Val{N}, a2::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
...
end

Can’t you use LazyArrays.jl?

I am not aware of it doing anything special for static vectors. Did I miss something?

Also, the structure above is rather special so I am hoping that I can do something clever.

It doesn’t need to do anything special:

julia> using LazyArrays, StaticArrays, FillArrays

julia> n,m,k = 3,4,5;

julia> a2,θ = @SVector(randn(k)),@SVector(randn(m*n*k));

julia> ApplyArray(*, ApplyArray(kron, Eye(n), a2', Eye(m)), θ)
0.3749523787239248
-1.170860382691954
0.9813174805425123
2.2870804492085264
-0.7093519111205916
0.8346906237589655
2.126796942875452
-2.051815270138217
-1.279417832090675
2.76420566315552
1.668282808022956
0.7445085628135768


Unless you want a StaticApplyArray type so it knows the dimensions at compile time?

Yes, I want the result to be an SVector{M * N}.

I am hoping that there is a reshuffling of θ which would reduce this to a simple matrix-vector product, which would be unrolled. Generally I have $mnk \le 30$.

Can you do anything with this?

julia> using LinearAlgebra  # provides I, kron and norm

julia> n,k,m = 3,4,5;
julia> θ = rand(m*k*n); a = rand(k);  # a is the length-k vector
julia> A = kron( I(n), kron(a',I(m) ) );
julia> y = A*θ;

julia> z = vcat([ reshape(θ[j*k*m.+(1:m*k)],m,k)*a for j in 0:n-1 ]...);

julia> norm(y-z)
5.438959822042073e-16


Apologies if this relationship was already clear to you (especially if this is where you started…).
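
The same block-wise reshape can be written for static arrays. The following is a minimal sketch, assuming StaticArrays.jl; the name `f_blocks` is hypothetical, and whether it unrolls as well as hoped would need to be checked with `@btime`.

```julia
using StaticArrays

# Hypothetical helper: one static (M×K)*(K) product per identity block of I₁.
function f_blocks(::Val{N}, a::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
    blocks = ntuple(Val(N)) do j
        # Column-major M×K matrix built from the j-th length-M*K chunk of θ
        Θj = SMatrix{M,K}(ntuple(i -> θ[(j - 1) * M * K + i], Val(M * K)))
        Θj * a
    end
    return reduce(vcat, blocks)  # concatenate the N length-M results
end
```

Called as `f_blocks(Val(n), a, Val(m), θ)`, this mirrors the `vcat([...]...)` expression above, with each `reshape(...) * a` replaced by an `SMatrix` product.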

julia> SVector{12}(ApplyArray(*, ApplyArray(kron, Eye(n), a2', Eye(m)), θ))
12-element SArray{Tuple{12},Float64,1,12} with indices SOneTo(12):
1.2642710321481372
3.2690469214793847
6.100075941415405
-0.5696072139699527
-0.07671184867032166
0.9397326767055554
-1.7540761677701062
2.4079231817499878
0.34952059878993136
-1.1246010634555499
-0.5592296926986757
2.1998373648537384


This shouldn’t allocate. Unfortunately it does, but that is just due to a couple of missing overloads (e.g. getindex for a 3-term Kronecker product falls back to calling kron; a 2-term Kronecker product works, though).

You can just write it out? res_ae = Σ_bcf δ_ab v_c δ_ef θ_bcf = Σ_c v_c θ_ace where a,b=1:n, c=1:k, e,f=1:m and I called the vector v. So only one sum survives, and once you allow for kron having its conventions backwards, I think you get this:

n,m,k = 3,4,5;
a2,θ = randn(k), randn(m*n*k)
using LinearAlgebra, Einsum
res = kron(I(n), a2', I(m)) * θ

θ3 = reshape(θ, m,k,n);
res ≈ vec(@einsum r2[e,a] := θ3[e,c,a] * a2[c])

res ≈ [sum(a2[c+1] * θ[1 + m*k*a + m*c + e] for c in 0:k-1) for a in 0:n-1 for e in 0:m-1 ]
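
For reference, the surviving sum can be written with explicit linear indices. This only restates the comprehension above, assuming the column-major `reshape(θ, m, k, n)`, zero-based $a$, $c$, $e$, and $v$ for the vector:

```latex
y_{1 + m a + e} \;=\; \sum_{c=0}^{k-1} v_{c+1}\,\theta_{1 + mk\,a + m\,c + e},
\qquad a = 0,\dots,n-1,\quad e = 0,\dots,m-1 .
```

So $e$ moves fastest in both $y$ and $\theta$, and consecutive terms of the sum are $m$ entries apart in $\theta$.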


And then you could make this an ntuple, or perhaps easier a generated function:

@generated function f(::Val{N}, a2::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
    vals = []
    for a in 0:N-1, e in 0:M-1
        terms = []
        for c in 0:K-1
            i = 1 + M*K*a + M*c + e
            push!(terms, :(a2[$c+1] * θ[$i]))
        end
        push!(vals, :(+($(terms...))))
    end
    :(SVector($(vals...)))
end
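
The ntuple alternative mentioned above could look roughly like this. It is only a sketch with the same index arithmetic; `f_ntuple` is a hypothetical name, and whether the compiler unrolls it as completely as the generated function would have to be checked.

```julia
using StaticArrays

# Hypothetical ntuple-based variant of `f`: same indices, no @generated.
function f_ntuple(::Val{N}, a2::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
    vals = ntuple(Val(N * M)) do idx
        a, e = divrem(idx - 1, M)   # zero-based block index a, offset e within the block
        base = 1 + M * K * a + e    # position of the c = 0 term in θ
        sum(ntuple(c -> a2[c] * θ[base + M * (c - 1)], Val(K)))
    end
    return SVector(vals)
end
```

It takes the same arguments as `f`, e.g. `f_ntuple(Val(n), a2, Val(m), θ)` with `a2` and `θ` as `SVector`s.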

I think this is about halfway to the solution you want:

using StaticArrays

# compute y = (I₁ ⊗ a₂ᵗ ⊗ I₃) * θ
function kronmul(::Val{N}, a₂::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
    y = zeros(M*N)

    # crawls along θ with unit stride
    for i in 1:N, k in 1:K, j in 1:M
        y[M*(i-1) + j] += a₂[k] * θ[M*K*(i-1) + M*(k-1) + j]
    end

    return SVector{M*N}(y)
end


Here’s an example:

using LinearAlgebra
using BenchmarkTools

N, K, M = 3, 4, 5

I₁ = I(N)
a₂ = SVector{K}(rand(K))
I₃ = I(M)
θ = SVector{K*M*N}(rand(K*M*N))
vN = Val(N)
vM = Val(M)

# correctness
expected = kron(I₁, a₂', I₃)*θ;
observed = kronmul(vN, a₂, vM, θ);
expected ≈ observed  # true

# benchmark
@benchmark kronmul($vN,$a₂, $vM,$θ)
# BenchmarkTools.Trial:
#   memory estimate:  208 bytes
#   allocs estimate:  1
#   --------------
#   minimum time:     106.866 ns (0.00% GC)
#   median time:      118.005 ns (0.00% GC)
#   mean time:        126.083 ns (5.80% GC)
#   maximum time:     2.194 μs (92.02% GC)
#   --------------
#   samples:          10000
#   evals/sample:     947


If that looks good but you want to be able to write A*θ (with A representing the appropriate Kronecker product), you could (ab)use LinearMaps.jl to implement A_mul_B! and/or At_mul_B!. The following is a hacky, incomplete implementation:

import Base: *, size
import LinearMaps: LinearMap

# represents (I₁ ⊗ a₂ᵗ ⊗ I₃) which is M*N by K*M*N
struct KronOperator{N,K,M,T} <: LinearMap{T}
    a::SVector{K,T}

    function KronOperator{N,M}(a::SVector{K,T}) where {N,K,M,T}
        new{N,K,M,T}(a)
    end
end

Base.size(::KronOperator{N,K,M}) where {K,M,N} = (M*N, K*M*N)

# not needed, I only did this to avoid writing a 'mutating' version of kronmul
function Base.:(*)(A::KronOperator{N,K,M}, x::AbstractVector) where {N,K,M}
    length(x) == size(A,2) || throw(DimensionMismatch())
    y = kronmul(Val(N), A.a, Val(M), x)
    return y
end


It seems to have similar performance to calling my kronmul from before:

N, K, M = 3, 4, 5

I₁ = I(N)
a₂ = SVector{K}(rand(K))
I₃ = I(M)
θ = SVector{K*M*N}(rand(K*M*N))
A = KronOperator{N,M}(a₂)

# correctness
expected = kron(I₁, a₂', I₃)*θ;
observed = A*θ;
expected ≈ observed # true

# benchmark
@benchmark *($A,$θ)
# BenchmarkTools.Trial:
#   memory estimate:  208 bytes
#   allocs estimate:  1
#   --------------
#   minimum time:     105.308 ns (0.00% GC)
#   median time:      120.011 ns (0.00% GC)
#   mean time:        130.112 ns (5.67% GC)
#   maximum time:     2.916 μs (94.41% GC)
#   --------------
#   samples:          10000
#   evals/sample:     946


EDIT: I missed @mcabbott’s solution, which appears to implement the same algorithm. However, there is a difference in performance for some reason:

@benchmark f($Val(N),$a₂, $Val(M),$θ)
BenchmarkTools.Trial:
memory estimate:  672 bytes
allocs estimate:  3
--------------
minimum time:     6.812 μs (0.00% GC)
median time:      7.012 μs (0.00% GC)
mean time:        7.099 μs (0.00% GC)
maximum time:     23.861 μs (0.00% GC)
--------------
samples:          10000
evals/sample:     5

I think this interpolates Val but not N:

julia> @btime f(Val(N), $a₂, Val(M),$θ); # not interpolating N,M
4.572 μs (3 allocations: 672 bytes)

julia> @btime f($vN,$a₂, $vM,$θ); # as used for kronmul
0.029 ns (0 allocations: 0 bytes)


The difference is that kronmul allocates an ordinary array y = zeros(M*N), while f avoids this. Replacing that with y = @MVector zeros(M*N) brings it to zero allocations:

julia> @btime kronmul($vN,$a₂, $vM,$θ);
83.992 ns (1 allocation: 208 bytes)

julia> @btime kronmulM($vN,$a₂, $vM,$θ);
32.326 ns (0 allocations: 0 bytes)
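
`kronmulM` is not spelled out in the thread. Presumably it is `kronmul` with the accumulator swapped as described, so the following is a reconstruction under that assumption rather than the exact code that was benchmarked:

```julia
using StaticArrays

# Assumed definition of `kronmulM`: same loop as `kronmul`, but with a
# stack-friendly MVector accumulator (Float64) instead of a heap-allocated Vector.
function kronmulM(::Val{N}, a₂::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
    y = @MVector zeros(M*N)

    # crawls along θ with unit stride
    for i in 1:N, k in 1:K, j in 1:M
        y[M*(i-1) + j] += a₂[k] * θ[M*K*(i-1) + M*(k-1) + j]
    end

    return SVector{M*N}(y)
end
```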


I’m not sure we should quite believe these numbers; perhaps something like this is more representative:

julia> @btime f($vN, a₂,$vM, θ)  setup=(θ=@SVector(rand(K*M*N)), a₂=@SVector(rand(K)));
35.083 ns (1 allocation: 128 bytes)

julia> @btime kronmul($vN, a₂,$vM, θ)  setup=(θ=@SVector(rand(K*M*N)), a₂=@SVector(rand(K)));
107.175 ns (2 allocations: 336 bytes)

julia> @btime kronmulM($vN, a₂,$vM, θ)  setup=(θ=@SVector(rand(K*M*N)), a₂=@SVector(rand(K)));
54.427 ns (1 allocation: 128 bytes)
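
Another way to probe the doubt above: the BenchmarkTools manual suggests wrapping interpolated inputs in a `Ref` and dereferencing them inside the call, so that the compiler cannot treat them as constants. Sub-nanosecond timings such as 0.029 ns usually mean the work was folded away. A sketch with a stand-in kernel (`kernel` is hypothetical and only keeps the block self-contained; substitute `f`, `kronmul`, or `kronmulM`):

```julia
using StaticArrays, BenchmarkTools

# Stand-in for the function under test (hypothetical; not from the thread).
kernel(a::SVector, θ::SVector) = sum(a) * sum(θ)

a = @SVector rand(4)
θ = @SVector rand(60)

# Plain interpolation: the inputs may be visible to the compiler as constants.
@btime kernel($a, $θ);

# Ref-and-dereference: the values are loaded at run time on every evaluation.
@btime kernel($(Ref(a))[], $(Ref(θ))[]);
```

No timings are quoted here, since they depend on the machine and package versions.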


Edit: one more MArray solution:

julia> function f_ein(::Val{N}, a2::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
           θ3 = reshape(θ, M,K,N);
           r2 = @MMatrix zeros(M,N)
           @einsum r2[e,a] = θ3[e,c,a] * a2[c]
           SVector{M*N}(r2)
       end

julia> @btime f_ein($vN,$a₂, $vM,$θ)
19.785 ns (0 allocations: 0 bytes)

julia> @btime f_ein($vN, a₂,$vM, θ)  setup=(θ=@SVector(rand(K*M*N)), a₂=@SVector(rand(K)));
40.470 ns (1 allocation: 128 bytes)

I would like to thank everyone for the suggestions. In the end I decided to go with the generated function approach by @mcabbott (which I think has an indexing bug, but I rewrote it anyway). I also kept the einsum version in the benchmarks because I find the solution really neat.

### Complete code for benchmarking

using StaticArrays, BenchmarkTools, LinearAlgebra, Einsum

f_naive(::Val{M}, a, ::Val{N}, θ) where {M,N} = kron(I(M), permutedims(a), I(N)) * θ

@generated function f_generated(::Val{M}, a::SVector{K}, ::Val{N}, θ::SVector) where {M,K,N}
    function _f(n, m)
        offset = (m - 1) * K * N + n
        :(+$(map(k -> :(a[$(k)] * θ[$(offset + (k - 1) * N)]), 1:K)...))
    end
    :(SVector($(vec(_f.(1:N, permutedims(1:M)))...)))
end

function f_ein(::Val{N}, a2::SVector{K}, ::Val{M}, θ::SVector) where {N,K,M}
    θ3 = reshape(θ, M,K,N);
    r2 = @MMatrix zeros(M,N)
    @einsum r2[e,a] = θ3[e,c,a] * a2[c]
    SVector{M*N}(r2)
end

K = 3
M = 4
N = 5

@btime f_naive(Val($M), a, Val($N), θ) setup=(θ=@SVector(rand(K*M*N)); a=@SVector(rand(K)));
@btime f_generated(Val($M), a, Val($N), θ) setup=(θ=@SVector(rand(K*M*N)); a=@SVector(rand(K)));
@btime f_ein(Val($M), a, Val($N), θ) setup=(θ=@SVector(rand(K*M*N)); a=@SVector(rand(K)));


### Results on master

julia> @btime f_naive(Val($M), a, Val($N), θ) setup=(θ=@SVector(rand(K*M*N)); a=@SVector(rand(K)));
1.976 μs (7 allocations: 10.92 KiB)

julia> @btime f_generated(Val($M), a, Val($N), θ) setup=(θ=@SVector(rand(K*M*N)); a=@SVector(rand(K)));
452.690 ns (3 allocations: 704 bytes)

julia> @btime f_ein(Val($M), a, Val($N), θ) setup=(θ=@SVector(rand(K*M*N)); a=@SVector(rand(K)));
450.203 ns (3 allocations: 704 bytes)
