For some tensor operations I need to keep track of partial indices and call out to user supplied functions. In principle the type of all arguments can be inferred statically, but Julia doesn’t seem to do that, leading to a lot of dynamic dispatch. How can I coax Julia into inferring types?
Example
# Some user supplied function
function bar(I::NTuple{3,Int})
sum(I)
end
function foo(::Val{NDims}, f::F) where {NDims,F<:Function}
dims = ntuple(_ -> 10, NDims)
#Generate Tuples of Lists of Tuples, such that
# Is .|> getindex .|> length +
# Js .|> getindex .|> length .== NDims
# Can't really pack these two, since elements are added and removed independendently
Is = let idxs::NTuple{NDims,Int} = dims .÷ 2
ntuple(
(@inline function (i)
[idxs[1:i-1]]
end),
NDims + 1
)
end
Js = let idxs::NTuple{NDims,Int} = dims .÷ 2
ntuple(
(@inline function (i)
[idxs[i:end]]
end),
NDims + 1
)
end
acc = 0
for k in 1:NDims+1
for i in Is[k], j in Js[k]
# This function call is dynamically dispatched, while the argument type is
# is always `NTuple{NDims, Int}`
acc += f((i..., j...))
end
end
acc
end
foo(Val(3), bar)
@code_warntype foo(Val(3), bar)
Output
MethodInstance for foo(::Val{3}, ::typeof(bar))
from foo(::Val{NDims}, f::F) where {NDims, F<:Function} @ Main ~/.julia/dev/TTApproximations/play.jl:26
Static Parameters
NDims = 3
F = typeof(bar)
Arguments
#self#::Core.Const(foo)
_::Core.Const(Val{3}())
f::Core.Const(bar)
Locals
…
@_14::Union{Nothing, Tuple{Any, Int64}}
k::Int64
@_16::Union{Nothing, Tuple{Any, Int64}}
i@_17::Any
j::Any
i@_19::Any
Note that NDims
is usually small, so the loop could be unrolled. without too many issues.