Tell the compiler the return type of a function chosen at runtime

The following function (a minimal reproducible example):

function fun(idx, a::Int, b::Int)
    funs = [+, -]
    return funs[idx](a, b)
end

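is not type-stable. Because + and - have different types, funs is a Vector{Function}, whose element type is abstract, so Julia cannot infer which function funs[idx] is or what the call returns. Here is the output of @code_warntype fun(1, 2, 3):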

MethodInstance for fun(::Int64, ::Int64, ::Int64)
  from fun(idx, a, b) @ Main 
Arguments
  #self#::Core.Const(fun)
  idx::Int64
  a::Int64
  b::Int64
Locals
  funs::Vector{Function}
Body::Any
1 ─      (funs = Base.vect(Main.:+, Main.:-))
│   %2 = Base.getindex(funs, idx)::Function
│   %3 = (%2)(a, b)::Any
└──      return %3
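A quick way to see where the instability comes from, using only Base: + and - each have their own concrete type, so the vector literal widens its element type to the abstract Function, and Base.return_types reports the same Any that @code_warntype shows. A minimal sketch that redefines fun exactly as above:

```julia
# Original definition from the question
function fun(idx, a::Int, b::Int)
    funs = [+, -]
    return funs[idx](a, b)
end

# `+` and `-` have distinct concrete types...
@show typeof(+) === typeof(-)          # false
# ...so the vector's element type is widened to the abstract `Function`
@show eltype([+, -])                   # Function
@show isconcretetype(eltype([+, -]))   # false
# Inference therefore cannot narrow the result of the call
@show Base.return_types(fun, (Int, Int, Int))
```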

Suppose I know the return type will always be Int64. Is there an annotation, or a parametric type for funs, that tells the compiler this? More generally, is there another way to help the compiler infer the return type in cases like this?

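You can use FunctionWrappers.jl. A FunctionWrapper{Ret, Args} is a concrete type that fixes the return type Ret and the argument types Args (a Tuple type), so a vector of them has a concrete element type: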

import FunctionWrappers: FunctionWrapper

function fun(idx, a::Int, b::Int)
    funs = FunctionWrapper{Int64, Tuple{Int64, Int64}}[+, -]
    return funs[idx](a, b)
end

@code_warntype fun(1, 3, 5)

There is not a lot of documentation available, but I believe it is used quite extensively across the Julia ecosystem. Hopefully someone else can fill in the technical details.
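The snippet above rebuilds the vector on every call. A variant worth considering is to build the typed table once as a const global. This is a sketch assuming FunctionWrappers.jl is installed; IntBinOp, OPS, apply_op and the extra * entry are hypothetical names added for illustration, and the final check mirrors what @code_warntype should show:

```julia
import FunctionWrappers: FunctionWrapper

# Concrete wrapper type: takes (Int, Int) and returns Int
const IntBinOp = FunctionWrapper{Int, Tuple{Int, Int}}

# Hypothetical dispatch table, built once rather than on every call.
# Its element type IntBinOp is concrete, unlike Vector{Function}.
const OPS = IntBinOp[IntBinOp(+), IntBinOp(-), IntBinOp(*)]

# Hypothetical helper: pick an operation at runtime and call it
apply_op(idx::Int, a::Int, b::Int) = OPS[idx](a, b)

@show apply_op(1, 2, 3)   # 2 + 3
@show apply_op(3, 2, 3)   # 2 * 3
@show Base.return_types(apply_op, (Int, Int, Int))
```

Wrapping each function explicitly with IntBinOp(...) avoids relying on the implicit conversion used by the typed array literal in the answer, though both should produce the same table.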


If you only have 2 or 3 functions, replacing the vector with a tuple also does the trick:

function fun(idx, a::Int, b::Int)
    funs = (+, -)
    return funs[idx](a, b)
end

julia> @code_warntype fun(1, 2, 3)
MethodInstance for fun(::Int64, ::Int64, ::Int64)
  from fun(idx, a::Int64, b::Int64) @ Main REPL[1]:1
Arguments
  #self#::Core.Const(fun)
  idx::Int64
  a::Int64
  b::Int64
Locals
  funs::Tuple{typeof(+), typeof(-)}
Body::Int64
1 ─      (funs = Core.tuple(Main.:+, Main.:-))
│   %2 = Base.getindex(funs, idx)::Union{typeof(+), typeof(-)}
│   %3 = (%2)(a, b)::Int64
└──      return %3
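Why this works: indexing the tuple (+, -) with a runtime idx yields the small union Union{typeof(+), typeof(-)}, and the compiler can split the call over each member of a small union. It stops doing so once the union grows past a handful of types, which is presumably why this advice is limited to 2 or 3 functions. A minimal sketch that checks the inferred return type programmatically; the name fun_tuple is only for illustration:

```julia
# Tuple version from the answer above, renamed so it can coexist with `fun`
function fun_tuple(idx, a::Int, b::Int)
    funs = (+, -)              # Tuple{typeof(+), typeof(-)}: each element keeps its own type
    return funs[idx](a, b)     # call is union-split over typeof(+) and typeof(-)
end

# Both branches return Int64, so inference can conclude Int64 overall
@show Base.return_types(fun_tuple, (Int, Int, Int))
@show fun_tuple(1, 2, 3), fun_tuple(2, 2, 3)
```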

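Another simple option is a type annotation on the result of the call, but note that the annotated type needs to be both concrete and known at compile time. The call itself is still dynamically dispatched (%3 below is still Any), but after the assertion the compiler knows the returned value is an Int64: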

function fun(idx, a::Int, b::Int)
    funs = [+, -]
    return funs[idx](a, b)::Int
end

1 ─      (funs = Base.vect(Main.:+, Main.:-))
│   %2 = Base.getindex(funs, idx)::Function
│   %3 = (%2)(a, b)::Any
│   %4 = Core.typeassert(%3, Main.Int)::Int64
└──      return %4
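The annotation is also a runtime check, not just a hint: if one of the functions ever returns something else, the call throws a TypeError instead of silently returning the wrong type. A sketch where fun_assert and the extra / entry are hypothetical additions for illustration (/ on two Ints returns a Float64):

```julia
function fun_assert(idx, a::Int, b::Int)
    funs = [+, -, /]            # `/` returns a Float64, breaking the assumption
    return funs[idx](a, b)::Int # assert (and let the compiler assume) an Int result
end

# Inferred as Int64 despite the dynamic call
@show Base.return_types(fun_assert, (Int, Int, Int))
@show fun_assert(1, 2, 3)

try
    fun_assert(3, 2, 3)         # 2 / 3 is a Float64, so the typeassert fails
catch err
    @show err isa TypeError
end
```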