Type stability with tuples is hard

The fundamental issue in your code is that you cannot use Julia’s multidimensional array machinery for this problem, because you do not know the number of dimensions at compile time. Instead, you have to reinvent the wheel to some extent and write multidimensional-array code which can deal with a number of dimensions specified only at runtime. Here’s my stab at this. It definitely loses out to @lmiq in terms of output prettiness, but the implementation of the factor product should be decent. There’s still some room for improvement, but I’d say we can get into that at a later point.

using LinearAlgebra
using SparseArrays

"""
    RuntimeCartesianIndices(dims)

Iterator over all multi-indices `i` with `1 <= i[k] <= dims[k]`, like
`CartesianIndices`, but with the number of dimensions `length(dims)` specified at
runtime. The first index varies fastest.

Each iteration yields a `Vector{Int}`. The *same* vector is reused as the iteration
state and is mutated in place on every step, so copy it if it has to outlive the
current iteration.
"""
struct RuntimeCartesianIndices{T <: AbstractVector{<:Integer}}
    dims::T
end

# Dispatch on `Type{<:RuntimeCartesianIndices}` so that the traits also apply to the
# concrete parametric types (e.g. `RuntimeCartesianIndices{Vector{Int}}`), not only to
# the unparametrized `RuntimeCartesianIndices` itself.
Base.IteratorSize(::Type{<:RuntimeCartesianIndices}) = Base.HasLength()
Base.IteratorEltype(::Type{<:RuntimeCartesianIndices}) = Base.HasEltype()
Base.eltype(::Type{<:RuntimeCartesianIndices}) = Vector{Int}
Base.length(I::RuntimeCartesianIndices) = prod(I.dims)

# Start of the iteration: the all-ones index, or nothing if any dimension is empty.
function Base.iterate(I::RuntimeCartesianIndices)
    # `any` is well-defined for empty `dims` (where there is exactly one, empty, index),
    # whereas reducing with `|` over an empty collection without `init` is an error.
    any(iszero, I.dims) && return nothing
    i = ones(Int, length(I.dims))
    return i, i
end

# Advance the index `i` in place, odometer-style: bump the first position that has not
# reached its bound yet and reset all positions before it to 1.
function Base.iterate(I::RuntimeCartesianIndices, i)
    for (k, d) in enumerate(I.dims)
        if i[k] < d
            i[k] += 1
            return i, i
        end
        i[k] = 1
    end
    # Every position was at its bound: the iteration is exhausted
    return nothing
end




"""
    Factor{T}

A table of values of type `T` indexed by a list of variables.

# Fields
- `vars`: the variables of the factor (arbitrary labels, compared with `isequal` in
  `factor_product`).
- `dims`: `dims[k]` is the number of states of `vars[k]`.
- `vals`: the table entries as a flat vector, with the index of the first variable
  varying fastest.
"""
struct Factor{T}
    vars::Vector{Any}
    dims::Vector{Int}
    vals::Vector{T}
end

# The element type of a factor is the type of its table entries
function Base.eltype(::Type{Factor{T}}) where {T}
    return T
end

"""
    factor_product(a::Factor, b::Factor)

Return the product of the factors `a` and `b` as a new `Factor`.

The variables of the result are the variables of `a` followed by those variables of `b`
which do not occur in `a` (the "tail"), in the order in which they appear in `b`. Each
entry of the result is the product of the entry of `a` and the entry of `b` selected by
the corresponding variable assignment.

Variables are matched with `isequal`; a variable occurring in both factors must have the
same size in both, otherwise an `AssertionError` is raised.
"""
function factor_product(a::Factor, b::Factor)
    ##################################
    # Merge `vars` and `dims` vectors

    cvars = copy(a.vars)
    cdims = copy(a.dims)

    # Current stride in `b`
    sb = 1
    # Strides of the `a` indices in `b` (zero for variables which do not occur in `b`)
    sba = zeros(Int, length(a.vars))
    # Strides of the tail indices in `b`
    sbt = Int[]
    # ^ These variables are needed for multiplying the correct `vals` later

    for i in eachindex(b.dims)
        j = findfirst(isequal(b.vars[i]), a.vars)
        if isnothing(j)
            # New variable: append it to the result and remember its stride in `b`
            push!(cvars, b.vars[i])
            push!(cdims, b.dims[i])
            push!(sbt, sb)
        else
            @assert a.dims[j] == b.dims[i] "size mismatch for variable $(b.vars[i]): $(a.dims[j]) vs $(b.dims[i])"
            sba[j] = sb
        end
        sb *= b.dims[i]
    end

    ######################
    # Multiply the `vals`

    T = typeof(zero(eltype(a)) * zero(eltype(b)))
    cvals = Vector{T}(undef, prod(cdims))

    # Loop invariants. With strides `s` and one-based index `i`, the zero-based linear
    # offset is `sum(s .* (i .- 1)) == dot(s, i) - sum(s)`.
    na = prod(a.dims)
    offset_t = sum(sbt)
    offset_a = sum(sba) - 1  # the `- 1` makes the `a` part a one-based position in `b`
    tail_dims = @view cdims[length(a.dims)+1:end]

    for (nt, it) in enumerate(RuntimeCartesianIndices(tail_dims))
        # Start of the block of `na` result entries belonging to tail index `it`
        lt = na * (nt - 1)
        lbt = dot(sbt, it) - offset_t
        for (la, ia) in enumerate(RuntimeCartesianIndices(a.dims))
            lba = dot(sba, ia) - offset_a
            cvals[la + lt] = a.vals[la] * b.vals[lbt + lba]
        end
    end
    return Factor(cvars, cdims, cvals)
end

# Multiply a factor over (A, B) with a factor over (B, C) and print the resulting table.
function test()
    table_ab = [
        0.5 0.8
        0.1 0.0
        0.3 0.9
    ]
    table_bc = [
        0.5 0.7
        0.1 0.2
    ]
    a = Factor(Any[:A, :B], [3, 2], vec(table_ab))
    b = Factor(Any[:B, :C], [2, 2], vec(table_bc))
    c = factor_product(a, b)

    # Reverse the order of the three axes before flattening; this prints exactly the
    # numbers from your example.
    table = reshape(c.vals, c.dims...)
    println(vec(permutedims(table, (3, 2, 1))))
end
3 Likes