Avoid Dynamic Dispatch

For some tensor operations I need to keep track of partial indices and call out to user supplied functions. In principle the type of all arguments can be inferred statically, but Julia doesn’t seem to do that, leading to a lot of dynamic dispatch. How can I coax Julia into inferring types?

Example

# Some user supplied function
function bar(I::NTuple{3,Int})
    sum(I)
end
function foo(::Val{NDims}, f::F) where {NDims,F<:Function}
    dims = ntuple(_ -> 10, NDims)
    #Generate Tuples of Lists of Tuples, such that 
    # Is  .|> getindex .|> length + 
    # Js  .|> getindex .|> length .== NDims
    # Can't really pack these two, since elements are added and removed independendently
    Is = let idxs::NTuple{NDims,Int} = dims .÷ 2
        ntuple(
            (@inline function (i)
                [idxs[1:i-1]]
            end),
            NDims + 1
        )
    end

    Js = let idxs::NTuple{NDims,Int} = dims .÷ 2
        ntuple(
            (@inline function (i)
                [idxs[i:end]]
            end),
            NDims + 1
        )
    end
    acc = 0
    for k in 1:NDims+1
        for i in Is[k], j in Js[k]
            # This function call is dynamically dispatched, while the argument type is
            # is always `NTuple{NDims, Int}`
            acc += f((i..., j...))
        end
    end
    acc
end

foo(Val(3), bar)
@code_warntype foo(Val(3), bar)

Output

MethodInstance for foo(::Val{3}, ::typeof(bar))
from foo(::Val{NDims}, f::F) where {NDims, F<:Function} @ Main ~/.julia/dev/TTApproximations/play.jl:26
Static Parameters
NDims = 3
F = typeof(bar)
Arguments
#self#::Core.Const(foo)
_::Core.Const(Val{3}())
f::Core.Const(bar)
Locals

@_14::Union{Nothing, Tuple{Any, Int64}}
k::Int64
@_16::Union{Nothing, Tuple{Any, Int64}}
i@_17::Any
j::Any
i@_19::Any

Note that NDims is usually small, so the loop could be unrolled. without too many issues.

typeof(Is) = Tuple{Vector{Tuple{}}, Vector{Tuple{Int64}}, Vector{Tuple{Int64, Int64}}, Vector{Tuple{Int64, Int64, Int64}}}

given this Tuple, compiler cannot know what Is[k] will be since it’s not unrolled.

try GitHub - cstjean/Unrolled.jl: Unrolling loops at compile-time

Unrolled doesn’t seem to work. I got a solution by using @generated:

# Some user supplied function
function bar(I::NTuple{3,Int})
   sum(I)
end


dims = ntuple(_ -> 10, 3)
NDims = length(dims)
function foo(dims::NTuple{NDims,Int}, f::F) where {F<:Function,NDims}
   #Generate Tuples of Lists of Tuples, such that 
   # Is  .|> getindex .|> length + 
   # Js  .|> getindex .|> length .== NDims
   Is = let idxs::NTuple{NDims,Int} = dims .÷ 2
       ntuple(
           (@inline function (i)
               [idxs[1:i-1]]
           end),
           NDims + 1
       )
   end

   Js = let idxs::NTuple{NDims,Int} = dims .÷ 2
       ntuple(
           (@inline function (i)
               [idxs[i:end]]
           end),
           NDims + 1
       )
   end
   acc = Base.RefValue(0)
   foo_steps!(acc, Is, Js, f, dims)
   acc[]
end

@generated function foo_steps!(acc, Is, Js, f::F, dims::NTuple{NDims,Int}) where {NDims,F<:Function}
   expr = :()
   for k in 1:NDims+1
       expr = quote
           $(expr)
           foo_inner!(acc, Is, Js, f, $k)
       end
   end
   :($expr, nothing)
end

@inline function foo_inner!(acc, Is, Js, f::F, k,) where {F<:Function}
   for i in Is[k], j in Js[k]
       # Some dummy function call.
       I = (i..., j...)
       acc[] += f(I)
   end
end

But I’m wondering if there’s a nicer solution.