Type instability in a performance-critical function

I have a function that is critical in my code because it is called many times. I am trying to optimize it.

The function takes an N-dimensional factor and a tuple of variables. A factor consists of an N-dimensional array and a tuple of integers that can be thought of as the “names” of each dimension. The objective of the function is to sum the elements along the dimensions specified by the tuple argument vars and wrap the result in a new factor, after having dropped the summed-out dims. Note: the number of values along a dimension tends to be very small (in fact, just two, since the variables are binary in most cases).
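
To make the operation concrete, here is a toy sketch of what “summing out” a dimension means (illustrative only, not my actual implementation):

```julia
# Toy sketch: "sum out" dimension 2 of a 2×2 array,
# then drop the now-singleton dimension.
A = [1.0 2.0; 3.0 4.0]
B = dropdims(sum(A, dims=2), dims=2)   # → [3.0, 7.0]
```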

Here is my current definition and a small example that calls it:

using BenchmarkTools

mutable struct FactorK{T, N}
  vars::NTuple{N,Int64}
  vals::Array
end

function factor_marginalization(A::FactorK{T}, vars::NTuple{N,Int64}) where {T,N}
  dims = indexin(vars, collect(A.vars)) # map vars to dims
  r_size = ntuple(d->d in dims ? 1 : size(A.vals,d), length(A.vars)) # assign 1 to summed out dims
  ret_vars = filter(v -> v ∉ vars, A.vars)
  r_vals = similar(A.vals, r_size)
  ret_vals = sum!(r_vals, A.vals) |> x -> dropdims(x, dims=Tuple(dims))
  FactorK{eltype(A.vals),length(ret_vars)}(ret_vars, ret_vals)
end
factor_marginalization(A::FactorK, vars::Int) = factor_marginalization(A, (vars,)) # wrap a single Int in a tuple

input = 1:8 |> x -> reshape(x, 2,2,2) |> float |> collect
A_vars = (1,2,3)
A_card = size(input)
A = FactorK{Float64,length(A_vars)}(A_vars, input)

vars = (1,2)

# factor_marginalization(A, vars)
@btime factor_marginalization(A, vars)
@code_warntype factor_marginalization(A, vars)

This implementation takes 1.226 μs (20 allocations: 1.23 KiB) according to @btime on my machine.

As you can see from the output of @code_warntype, my function isn’t type stable. I noticed that the first problem is the filter applied to a tuple when computing ret_vars. It makes sense that the resulting type cannot be inferred, since the number of elements in the result depends on the values of the operands. My question is: are there any tricks I can apply to make this function type-stable? If you have suggestions for changes to my FactorK definition that would make this function run faster, I would also be very interested to hear them.

Function barriers can help. For example, this speeds up the code a little (about 10%). Perhaps a more thorough separation of the type-unstable part from the computations could help even more:

function barrier(r_vals,ret_vars,Avals,dims)
  ret_vals = sum!(r_vals, Avals) |> x -> dropdims(x, dims=Tuple(dims))
  FactorK{eltype(Avals),length(ret_vars)}(ret_vars, ret_vals)
end

function factor_marginalization2(A::FactorK{T}, vars::NTuple{N,Int64}) where {T,N}
  dims = indexin(vars, collect(A.vars)) # map vars to dims
  r_size = ntuple(d->d in dims ? 1 : size(A.vals,d), length(A.vars)) # assign 1 to summed out dims
  ret_vars = filter(v -> v ∉ vars, A.vars)
  r_vals = similar(A.vals, r_size)
  barrier(r_vals,ret_vars,A.vals,dims)
end


Thanks a lot @lmiq. I’ve read a lot about function barriers and for some reason, it didn’t occur to me to use it here. I also see the performance gain you mention.

From the output of @code_warntype, it appears the problems all start from accessing the property A.vals, since that field is typed as ::Array (which is abstract). Since you are not actually mutating A (or the type of A) in the FactorK, but rather wrapping the result of the computation in a new struct, it might make more sense to parameterize FactorK with the dimensionality of vals. This should stabilize the whole thing.
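
Something along these lines, for example (just a sketch, with a hypothetical name FactorK2 to avoid clashing with your definition):

```julia
# Sketch: carry both the element type and the dimensionality of vals
# in the type parameters, so the field is a concrete Array{T,M}.
mutable struct FactorK2{T, N, M}
  vars::NTuple{N, Int64}
  vals::Array{T, M}
end

A2 = FactorK2((1, 2, 3), zeros(2, 2, 2))   # parameters inferred from the fields
```

With this, accessing A2.vals returns a concretely typed array, so everything downstream of the field access can be inferred.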


@tomerarnon That makes a lot of sense. I’m going to try your suggestions right now.

By the way, the T parameter isn’t used anywhere in FactorK. Did you mean to parameterize vals originally as ::Array{T}?


@tomerarnon ahh, I had this other implementation that I started from. I forgot to remove the T. This older implementation is more in line with what you suggest above:

using BenchmarkTools

mutable struct FactorF{V,C,T} # variables, cardinality, element type
  vals::T
end

import Base: eltype, ndims, size
getvars(::FactorF{V}) where {V} = V
#eltype(::FactorF{V,C,T}) where {V,C,T} = T.parameters[1]
eltype(::FactorF{V,C,T}) where {V,C,T} = T
ndims(::FactorF{V}) where {V} = length(V)
size(::FactorF{V,C}, d::Int64) where {V,C} = C[d]

function marginalize(A::FactorF, vars::Tuple)
  drop_dims = indexin(vars, collect(getvars(A))) |> Tuple # map V to dims
  r_card = ntuple(d -> d in drop_dims ? 1 : size(A,d), ndims(A)) # assign 1 to summed out dims
  B_card = filter(s -> s != 1, r_card)
  B_vars = filter(v -> v ∉ vars, getvars(A))
  r_vals = similar(A.vals, r_card)
  B_vals = sum!(r_vals, A.vals) |> x -> dropdims(x, dims=drop_dims)
  return FactorF{B_vars, B_card, Array{Float64,length(B_vars)}}(B_vals)
end

input = 1:8 |> x -> reshape(x, 2,2,2) |> float |> collect
A_vars = (1,2,3)
A_card = size(input)
A = FactorF{A_vars, A_card, Array{Float64,length(A_vars)}}(input)

vars = (1,2)
# @btime B = marginalize(A, vars)
# B = marginalize(A, vars) |> println
@code_warntype marginalize(A, vars)

Here I’m storing the variables, the cardinality (size of the array) and the element type in the type info of my struct. But as you can see, it also suffers from type instability (which makes sense, since B_vars and B_card are computed from runtime values and end up as type parameters of the result). And the execution time is larger: 1.350 μs (20 allocations: 1.22 KiB).


Straightforward removal of the abstract field type already gives a speed improvement:

using BenchmarkTools

mutable struct FactorK{T, N}
  vars::NTuple{N,Int64}
  vals::T
end

function factor_marginalization(A::FactorK{T}, vars::NTuple{N,Int64}) where {T,N}
  dims = indexin(vars, collect(A.vars)) # map vars to dims
  r_size = ntuple(d->d in dims ? 1 : size(A.vals,d), length(A.vars)) # assign 1 to summed out dims
  ret_vars = filter(v -> v ∉ vars, A.vars)
  r_vals = similar(A.vals, r_size)
  ret_vals = sum!(r_vals, A.vals) |> x -> dropdims(x, dims=Tuple(dims))
  FactorK(ret_vars, ret_vals)
end
factor_marginalization(A::FactorK, vars::Int) = factor_marginalization(A, (vars,)) # wrap a single Int in a tuple

input = 1:8 |> x -> reshape(x, 2,2,2) |> float |> collect
A_vars = (1,2,3)
A_card = size(input)
A = FactorK{Float64,length(A_vars)}(A_vars, input)

julia> @btime factor_marginalization($A, $vars)
  936.136 ns (18 allocations: 1.19 KiB)

A bit of “boiling down” reveals that this is equivalent to the original and performs about 10% faster:

function factor_marginalization4(Avals, Avars, vars)
    dims_to_drop = filter(in(vars), Avars)
    dims_left = filter(!in(vars), Avars)
    ret_vals = dropdims(sum(Avals, dims = (dims_to_drop)), dims = (dims_to_drop))
    FactorK{eltype(Avals), length(dims_left)}(dims_left, ret_vals)
end 

factor_marginalization4(A, vars) = factor_marginalization4(A.vals, A.vars, vars)

It is still type unstable due to the tuples. A function barrier helps a bit there, making it slightly faster:

function factor_marginalization5(Avals, Avars, vars)
    dims_to_drop = filter(in(vars), Avars)
    dims_left = filter(!in(vars), Avars)
    _factor_marginalization5(Avals, dims_to_drop, dims_left)
end
 
factor_marginalization5(A, vars) = factor_marginalization5(A.vals, A.vars, vars)

function _factor_marginalization5(Avals, dims_to_drop, dims_left)
    ret_vals = dropdims(sum(Avals, dims = (dims_to_drop)), dims = (dims_to_drop))
    FactorK{eltype(Avals), length(dims_left)}(dims_left, ret_vals)
end

julia> @btime factor_marginalization(A, vars)
  1.532 μs (20 allocations: 1.33 KiB)
FactorK{Float64,1}((3,), [10.0, 26.0])

julia> @btime factor_marginalization4(A, vars)
  1.308 μs (20 allocations: 688 bytes)
FactorK{Float64,1}((3,), [10.0, 26.0])

julia> @btime factor_marginalization5(A, vars)
  744.447 ns (16 allocations: 560 bytes)
FactorK{Float64,1}((3,), [10.0, 26.0])

I wonder whether using arrays rather than tuples would help in this case.

Note that I split A into its parts with a function barrier as a way to deal with the type instability. If A.vals were strictly typed you wouldn’t need this in either case.

@tomerarnon thank you very much. I’m going to explore with arrays instead of tuples to see if it helps.

That collect in indexin was bothering me. If you substitute it with a non-allocating function, it gets faster:

Code

function my_indexin(x,y)
  indxs = Int[]
  for (i, xval) in pairs(x)
    for (j, yval) in pairs(y)
      if yval == xval
        push!(indxs,j)
        break
      end
    end
  end
  indxs
end

function barrier(r_vals,ret_vars,Avals,dims)
  ret_vals = sum!(r_vals, Avals) |> x -> dropdims(x, dims=Tuple(dims))
  FactorK{eltype(Avals),length(ret_vars)}(ret_vars, ret_vals)
end

function factor_marginalization3(A::FactorK{T}, vars::NTuple{N,Int64}) where {T,N}
  dims = my_indexin(vars,A.vars)
  r_size = ntuple(d->d in dims ? 1 : size(A.vals,d), length(A.vars)) # assign 1 to summed out dims
  ret_vars = filter(v -> v ∉ vars, A.vars)
  r_vals = similar(A.vals, r_size)
  barrier(r_vals,ret_vars,A.vals,dims)
end


julia> @btime factor_marginalization($A, $vars) # original
  947.938 ns (20 allocations: 1.33 KiB)
FactorK{Float64,1}((3,), [10.0, 26.0])

julia> @btime factor_marginalization3($A, $vars)
  608.273 ns (15 allocations: 624 bytes)
FactorK{Float64,1}((3,), [10.0, 26.0])

Ahh thanks. I see that @tomerarnon also got rid of it. It would be nice if the Julia docs were more explicit about which functions allocate memory on the heap and which don’t. I have no idea how you figured out that that call was the problem.
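
(I suppose I can at least check a suspicious call myself with @allocated; e.g. collecting a tuple clearly hits the heap:)

```julia
# Measure allocations of a single call: collect builds a new Array.
t = (1, 2, 3)
alloc = @allocated collect(t)
alloc > 0   # true: collect had to heap-allocate an Array
```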

FYI, arrays seem to be slower after all:

julia> @btime factor_marginalization5($(A.vals), (1,2,3), (1,2))
  667.293 ns (14 allocations: 496 bytes)
FactorK{Float64,1}((3,), [10.0, 26.0])

julia> @btime factor_marginalization5($(A.vals), [1,2,3], [1,2])
  1.662 μs (22 allocations: 1.16 KiB)
FactorK{Float64,1}((3,), [10.0, 26.0])

The code was indexin(vars, collect(A.vars)), meaning that you are creating an array from a tuple with collect just to do a search in it afterwards. That is only needed because indexin does not accept a tuple as its second argument (could it? maybe it could, I don’t know). But you certainly can iterate over the elements of a tuple, and that is what I did.
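
For example, pairs works directly on a tuple (index => value), no collect needed:

```julia
# Find the position of a value in a tuple by iterating it directly,
# without allocating an Array copy (toy example).
t = (10, 20, 30)
found = Int[]
for (j, val) in pairs(t)
    val == 20 && push!(found, j)
end
# found holds [2], the position of 20
```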

Right, indexin does not take in tuples. Gonna think about it twice before I use collect for that purpose in the future. Thanks.

Could you show me the approach you used to get rid of the collect?

I thought that dims = filter(in(vars), A.vars) was a replacement for dims = indexin(vars, collect(A.vars)) but it isn’t, they produce different results.
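
For example, with made-up names that don’t coincide with their positions, the difference is clear:

```julia
# Hypothetical names where name != position:
Avars = (5, 6, 7)
vars  = (6, 7)
dims_a = indexin(vars, collect(Avars))   # → [2, 3]: positions (i.e. dims)
dims_b = filter(in(vars), Avars)         # → (6, 7): the names themselves
```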

It was there, hidden in the “> Code”. Please check if it is correct. I think it is.

One note: for small sizes this function is faster than the builtin indexin, but it is slower for larger arrays. Also, the builtin returns nothing for elements that are not found, while this one simply does not push anything into the result.
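
For example, this is the builtin behavior for a missing element:

```julia
# indexin keeps a placeholder for elements that are not found:
x = indexin((6, 9), [5, 6, 7])   # → Union{Nothing, Int64}[2, nothing]
```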

I completely missed it. Thanks a lot.