The fundamental issue in your code is that you cannot use Julia’s multidimensional-array machinery efficiently for this problem, because you do not know the number of dimensions at compile time. Instead, you have to reinvent the wheel to some extent and write multidimensional-array code that can deal with a number of dimensions specified only at runtime. Here’s my stab at this. It definitely loses out to @lmiq in terms of output prettiness, but the implementation of the factor product should be decent. There’s still some room for improvement, but I’d say we can get into that at a later point.
using LinearAlgebra
using SparseArrays
"""
    RuntimeCartesianIndices(dims)

Iterate over every index vector `i` with `1 <= i[k] <= dims[k]`, in column-major
order (the first index varies fastest). This is like `CartesianIndices`, but the
number of dimensions is only known at runtime.

An empty `dims` yields exactly one (empty) index vector, the zero-dimensional
case. Any zero entry in `dims` yields no index vectors at all.

!!! warning
    To avoid allocating, every iteration yields the *same* `Vector{Int}`, mutated
    in place. Call `copy(i)` if an index vector must outlive the current
    iteration. In particular, all elements of `collect(I)` are the same vector.
"""
struct RuntimeCartesianIndices{T <: AbstractVector{<:Integer}}
    dims::T
end

# `Type{<:...}` (not `Type{RuntimeCartesianIndices}`) is needed so these traits
# apply to concrete instances like `RuntimeCartesianIndices{Vector{Int}}`.
Base.IteratorSize(::Type{<:RuntimeCartesianIndices}) = Base.HasLength()
Base.IteratorEltype(::Type{<:RuntimeCartesianIndices}) = Base.HasEltype()
Base.eltype(::Type{<:RuntimeCartesianIndices}) = Vector{Int}
Base.length(I::RuntimeCartesianIndices) = prod(I.dims)

function Base.iterate(I::RuntimeCartesianIndices)
    # `any` is well-defined on empty `dims`; `mapreduce(iszero, |, ...)` throws
    # there instead of producing the single zero-dimensional index.
    any(iszero, I.dims) && return nothing
    i = ones(Int, length(I.dims))
    # The index vector doubles as the iteration state (see warning above).
    return i, i
end

function Base.iterate(I::RuntimeCartesianIndices, i)
    # Advance like an odometer: increment the first index that can still grow,
    # and wrap every index before it back to 1.
    for k in 1:length(I.dims)
        if i[k] < I.dims[k]
            i[k] += 1
            return i, i
        end
        i[k] = 1
    end
    # Every index wrapped around, so all combinations have been produced.
    return nothing
end
"""
    Factor(vars, dims, vals)

A factor over the variables `vars`, where variable `vars[k]` takes `dims[k]`
distinct values. `vals` stores the entries in column-major order: it is `vec`
of an array of size `(dims...,)` whose first axis corresponds to `vars[1]`.

Throws `DimensionMismatch` if `vars` and `dims` differ in length, or if
`length(vals) != prod(dims)`.
"""
struct Factor{T}
    vars::Vector{Any}
    dims::Vector{Int}
    vals::Vector{T}
    function Factor{T}(vars, dims, vals) where {T}
        # Code indexing `vals` through `dims` relies on these invariants.
        length(vars) == length(dims) || throw(DimensionMismatch(
            "got $(length(vars)) variables but $(length(dims)) dimensions"))
        length(vals) == prod(dims) || throw(DimensionMismatch(
            "expected prod(dims) = $(prod(dims)) values, got $(length(vals))"))
        return new{T}(vars, dims, vals)
    end
end

# Defining an inner constructor removes the default outer one, so restore it.
Factor(vars::Vector{Any}, dims::Vector{Int}, vals::Vector{T}) where {T} =
    Factor{T}(vars, dims, vals)

Base.eltype(::Type{Factor{T}}) where {T} = T
"""
    factor_product(a::Factor, b::Factor)

Return the product of the factors `a` and `b` as a new `Factor`.

The result's variables are those of `a`, in order, followed by the variables of
`b` that do not occur in `a`, in the order they occur in `b`. Each entry is the
product of the entry of `a` and the entry of `b` selected by the same assignment
of values to the shared variables. The element type of the result is the type of
`zero(eltype(a)) * zero(eltype(b))`.

Throws `DimensionMismatch` if a variable shared by `a` and `b` has a different
dimension in each.
"""
function factor_product(a::Factor, b::Factor)
    # --- Merge `vars` and `dims`, recording where each variable lives in `b.vals`.
    cvars = copy(a.vars)
    cdims = copy(a.dims)
    # Stride of the current variable of `b` within the column-major `b.vals`.
    sb = 1
    # Stride in `b.vals` of each variable of `a`. It stays 0 for variables absent
    # from `b`, so they do not contribute to the computed `b` index.
    sba = zeros(Int, length(a.vars))
    # Strides in `b.vals` of the variables of `b` that get appended to the result.
    sbt = Int[]
    for i in eachindex(b.dims)
        var = b.vars[i]
        j = findfirst(isequal(var), a.vars)
        if isnothing(j)
            push!(cvars, var)
            push!(cdims, b.dims[i])
            push!(sbt, sb)
        else
            # Explicit check rather than `@assert`, which may be compiled out.
            a.dims[j] == b.dims[i] || throw(DimensionMismatch(
                "variable $(repr(var)) has dimension $(a.dims[j]) in `a` " *
                "but $(b.dims[i]) in `b`"))
            sba[j] = sb
        end
        sb *= b.dims[i]
    end

    # --- Multiply the `vals`.
    T = typeof(zero(eltype(a)) * zero(eltype(b)))
    cvals = Vector{T}(undef, prod(cdims))
    # Loop invariants. `na` is the size of each contiguous block of `cvals` that
    # corresponds to one tail index. The stride sums convert 1-based index
    # vectors into 0-based offsets: sum(s .* (i .- 1)) == dot(s, i) - sum(s).
    na = prod(a.dims)
    sumsba = sum(sba)
    sumsbt = sum(sbt)
    tail = @view cdims[length(a.dims)+1:end]
    for (kt, it) in enumerate(RuntimeCartesianIndices(tail))
        ct = na * (kt - 1)              # offset of this block within `cvals`
        bt = dot(sbt, it) - sumsbt      # 0-based `b` offset from the tail variables
        for (ka, ia) in enumerate(RuntimeCartesianIndices(a.dims))
            ba = dot(sba, ia) - sumsba + 1
            cvals[ct + ka] = a.vals[ka] * b.vals[bt + ba]
        end
    end
    return Factor(cvars, cdims, cvals)
end
"""
    test()

Multiply a factor over `(:A, :B)` by a factor over `(:B, :C)`, then print the
product's values reordered so that `:C` varies fastest and `:A` slowest.
"""
function test()
    table_ab = [
        0.5 0.8
        0.1 0.0
        0.3 0.9
    ]
    table_bc = [
        0.5 0.7
        0.1 0.2
    ]
    phi_ab = Factor(Any[:A, :B], [3, 2], vec(table_ab))
    phi_bc = Factor(Any[:B, :C], [2, 2], vec(table_bc))
    product = factor_product(phi_ab, phi_bc)
    # `product.vals` is column-major over (A, B, C). Reversing the axes before
    # flattening reproduces the number ordering of the original example.
    full = reshape(product.vals, Tuple(product.dims))
    println(vec(permutedims(full, (3, 2, 1))))
end