I have a performance-critical function that is called many times in my code, and I am trying to optimize it.
The function takes an N-dimensional factor and a tuple of variables. A factor consists of an N-dimensional array and a tuple of integers that can be thought of as the “names” of its dimensions. The function should sum the array over the dimensions named in the tuple argument `vars`, drop those summed-out dimensions, and wrap the result in a new `FactorK`. Note that the number of values along each dimension tends to be very small: in most cases it is just two, since the variables are usually binary.
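On a plain array, this operation amounts to summing over the chosen dimensions and then dropping them. The sketch below uses only Base functions and the same 2×2×2 input as the example that follows. It assumes, as in that example, that the variable names `(1,2,3)` coincide with the dimension indices, so summing out variables 1 and 2 means summing over dimensions 1 and 2:

```julia
# 2×2×2 array holding 1.0, 2.0, ..., 8.0 in column-major order
x = reshape(collect(1.0:8.0), 2, 2, 2)

# summing over dims 1 and 2 keeps them as singleton dims: size (1, 1, 2)
s = sum(x; dims=(1, 2))

# dropping the singleton dims leaves a vector over the remaining variable
m = dropdims(s; dims=(1, 2))  # [10.0, 26.0]
```

Each entry of `m` is the sum of one 2×2 slice: 1+2+3+4 = 10 and 5+6+7+8 = 26.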
Here is my current definition, along with a small example that calls it:
```julia
using BenchmarkTools
using InteractiveUtils  # provides @code_warntype outside the REPL

mutable struct FactorK{T, N}
    vars::NTuple{N,Int64}  # "names" of the dimensions of vals
    vals::Array            # the N-dimensional array of values
end

function factor_marginalization(A::FactorK{T}, vars::NTuple{N,Int64}) where {T,N}
    dims = indexin(vars, collect(A.vars))  # map vars to dims
    r_size = ntuple(d -> d in dims ? 1 : size(A.vals, d), length(A.vars))  # assign 1 to summed-out dims
    ret_vars = filter(v -> v ∉ vars, A.vars)  # variables that remain after summing
    r_vals = similar(A.vals, r_size)
    ret_vals = sum!(r_vals, A.vals) |> x -> dropdims(x, dims=Tuple(dims))  # sum, then drop singleton dims
    FactorK{eltype(A.vals),length(ret_vars)}(ret_vars, ret_vals)
end

# single-variable convenience method: wrap the Int in a 1-tuple
factor_marginalization(A::FactorK, vars::Int) = factor_marginalization(A, (vars,))

input = 1:8 |> x -> reshape(x, 2,2,2) |> float |> collect
A_vars = (1,2,3)
A_card = size(input)
A = FactorK{Float64,length(A_vars)}(A_vars, input)
vars = (1,2)
# factor_marginalization(A, vars)
@btime factor_marginalization(A, vars)
@code_warntype factor_marginalization(A, vars)
```
According to `@btime`, this implementation takes 1.226 μs (20 allocations: 1.23 KiB) on my computer.
As the output of `@code_warntype` shows, my function isn't type stable. The first problem I noticed is the `filter` applied to a tuple to compute `ret_vars`. It makes sense that the resulting type cannot be inferred, since the number of elements depends on the values of the operands, not just their types. My question is: are there any tricks I can apply to make this function type stable? I would also be very interested to hear about any changes to my `FactorK` definition that would make this function run faster.
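The value dependence described above is easy to see directly: two inputs of the same type `NTuple{3,Int}` can produce tuples of different lengths. One possible workaround is sketched below. It assumes every entry of `vars` appears exactly once in the factor's variables. Under that assumption, the output length `N - M` can be computed from the type parameters alone, and the tuple can be built with `ntuple(f, Val(K))`, so the compiler can infer `NTuple{N-M,Int}`. The helper name `remaining_vars` is hypothetical and does not come from any library.

```julia
using Test  # provides @inferred

# filter's output length depends on the tuple's values, not just its type:
length(filter(isodd, (1, 2, 3)))  # 2
length(filter(isodd, (2, 4, 3)))  # 1

# Hypothetical helper (not from any library): the elements of `avars`
# that are not in `vars`. Assumes every element of `vars` occurs exactly
# once in `avars`, so the result always has exactly N - M elements.
function remaining_vars(avars::NTuple{N,Int}, vars::NTuple{M,Int}) where {N,M}
    # N and M are type parameters, so Val(N - M) is known at compile time
    return ntuple(Val(N - M)) do i
        # walk `avars` and return the i-th element that is not summed out
        seen = 0
        for v in avars
            if v ∉ vars
                seen += 1
                seen == i && return v
            end
        end
        throw(ArgumentError("vars must be distinct elements of avars"))
    end
end

remaining_vars((1, 2, 3), (1, 2))            # (3,)
@inferred remaining_vars((1, 2, 3), (1, 2))  # passes: inferred as Tuple{Int}
```

This sketch only addresses `ret_vars`. The `vals::Array` field of `FactorK` fixes neither the element type nor the number of dimensions, so the compiler cannot infer the concrete type of `A.vals`. The Julia performance tips recommend concretely typed fields such as `vals::Array{T,N}` instead. Similarly, `indexin` returns a `Vector`, so the compiler cannot know the length of `Tuple(dims)`.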