Type-stability with a vector of FunctionWrappers

I have the following code that I want to use for storing a vector of functions:

using FunctionWrappers
import FunctionWrappers: FunctionWrapper

abstract type AbstractF end
struct F{P,T,U} <: AbstractF
    p::P
    f::FunctionWrapper{U,Tuple{T,T,T,U,P}}
end
struct G{P,T,U} <: AbstractF
    p::P
    f::FunctionWrapper{U,Tuple{T,T,T,U,P}}
end
(F::AbstractF)(a, b, c, d) = F.f(a, b, c, d, F.p)

When I use a vector of functions of say type F, this is all fine:

function evaluate_list_of_wrappers(wrappers)
    nums = zeros(length(wrappers))
    for (i, f) in pairs(wrappers)
        a, b, c, d = rand(4)
        nums[i] = f(a, b, c, d)
    end
    return nums
end

f1 = (a, b, c, d, p) -> a * b
f2 = (a, b, c, d, p) -> c * p[1]
g1 = (a, b, c, d, p) -> a + b + c
g2 = (a, b, c, d, p) -> a - b + c
p = (1.0, 5)
F1 = F{typeof(p),Float64,Float64}(p, f1)
F2 = F{typeof(p),Float64,Float64}(p, f2)
G1 = G{typeof(p),Float64,Float64}(p, g1)
G2 = G{typeof(p),Float64,Float64}(p, g2)
@code_warntype evaluate_list_of_wrappers((F1, F2)) # Fine 
julia> @code_warntype evaluate_list_of_wrappers((F1, F2)) # Fine 
MethodInstance for evaluate_list_of_wrappers(::Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}})
  from evaluate_list_of_wrappers(wrappers) in Main at Untitled-1:19
Arguments
  #self#::Core.Const(evaluate_list_of_wrappers)
  wrappers::Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}}
Locals
  @_3::Union{Nothing, Tuple{Pair{Int64, F{Tuple{Float64, Int64}, Float64, Float64}}, Int64}}
  nums::Vector{Float64}
  @_5::Int64
  @_6::Int64
  f::F{Tuple{Float64, Int64}, Float64, Float64}
  i::Int64
  d::Float64
  c::Float64
  b::Float64
  a::Float64
Body::Vector{Float64}
1 ─ %1  = Main.length(wrappers)::Core.Const(2)
β”‚         (nums = Main.zeros(%1))
β”‚   %3  = Main.pairs(wrappers)::Core.PartialStruct(Base.Pairs{Int64, F{Tuple{Float64, Int64}, Float64, Float64}, Base.OneTo{Int64}, Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}}}, Any[Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}}, Core.Const(Base.OneTo(2))])
β”‚         (@_3 = Base.iterate(%3))
β”‚   %5  = (@_3::Core.PartialStruct(Tuple{Pair{Int64, F{Tuple{Float64, Int64}, Float64, Float64}}, Int64}, Any[Core.PartialStruct(Pair{Int64, F{Tuple{Float64, Int64}, Float64, Float64}}, Any[Core.Const(1), F{Tuple{Float64, Int64}, Float64, Float64}]), Core.Const(1)]) === nothing)::Core.Const(false)
β”‚   %6  = Base.not_int(%5)::Core.Const(true)
└──       goto #4 if not %6
2 β”„ %8  = @_3::Tuple{Pair{Int64, F{Tuple{Float64, Int64}, Float64, Float64}}, Int64}
β”‚   %9  = Core.getfield(%8, 1)::Pair{Int64, F{Tuple{Float64, Int64}, Float64, Float64}}
β”‚   %10 = Base.indexed_iterate(%9, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
β”‚         (i = Core.getfield(%10, 1))
β”‚         (@_6 = Core.getfield(%10, 2))
β”‚   %13 = Base.indexed_iterate(%9, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, Int64}, Any[F{Tuple{Float64, Int64}, Float64, Float64}, Core.Const(3)])  
β”‚         (f = Core.getfield(%13, 1))
β”‚   %15 = Core.getfield(%8, 2)::Int64
β”‚   %16 = Main.rand(4)::Vector{Float64}
β”‚   %17 = Base.indexed_iterate(%16, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
β”‚         (a = Core.getfield(%17, 1))
β”‚         (@_5 = Core.getfield(%17, 2))
β”‚   %20 = Base.indexed_iterate(%16, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(3)])
β”‚         (b = Core.getfield(%20, 1))
β”‚         (@_5 = Core.getfield(%20, 2))
β”‚   %23 = Base.indexed_iterate(%16, 3, @_5::Core.Const(3))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(4)])
β”‚         (c = Core.getfield(%23, 1))
β”‚         (@_5 = Core.getfield(%23, 2))
β”‚   %26 = Base.indexed_iterate(%16, 4, @_5::Core.Const(4))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(5)])
β”‚         (d = Core.getfield(%26, 1))
β”‚   %28 = (f)(a, b, c, d)::Float64
β”‚         Base.setindex!(nums, %28, i)
β”‚         (@_3 = Base.iterate(%3, %15))
β”‚   %31 = (@_3 === nothing)::Bool
β”‚   %32 = Base.not_int(%31)::Bool
└──       goto #4 if not %32
3 ─       goto #2
4 β”„       return nums

But if I now include different types, I get an issue:

julia> @code_warntype evaluate_list_of_wrappers((F1, F2, G1, G2)) # ?
MethodInstance for evaluate_list_of_wrappers(::Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}})
  from evaluate_list_of_wrappers(wrappers) in Main at Untitled-1:19
Arguments
  #self#::Core.Const(evaluate_list_of_wrappers)
  wrappers::Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}}        
Locals
  @_3::Union{Nothing, Tuple{Pair{Int64, AbstractF}, Int64}}
  nums::Vector{Float64}
  @_5::Int64
  @_6::Int64
  f::AbstractF
  i::Int64
  d::Float64
  c::Float64
  b::Float64
  a::Float64
Body::Vector{Float64}
1 ─ %1  = Main.length(wrappers)::Core.Const(4)
β”‚         (nums = Main.zeros(%1))
β”‚   %3  = Main.pairs(wrappers)::Core.PartialStruct(Base.Pairs{Int64, AbstractF, Base.OneTo{Int64}, Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}}}, Any[Tuple{F{Tuple{Float64, Int64}, Float64, Float64}, F{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}, G{Tuple{Float64, Int64}, Float64, Float64}}, Core.Const(Base.OneTo(4))])
β”‚         (@_3 = Base.iterate(%3))
β”‚   %5  = (@_3::Core.PartialStruct(Tuple{Pair{Int64, AbstractF}, Int64}, Any[Core.PartialStruct(Pair{Int64, AbstractF}, Any[Core.Const(1), F{Tuple{Float64, Int64}, Float64, Float64}]), Core.Const(1)]) 
=== nothing)::Core.Const(false)
β”‚   %6  = Base.not_int(%5)::Core.Const(true)
└──       goto #4 if not %6
2 β”„ %8  = @_3::Tuple{Pair{Int64, AbstractF}, Int64}
β”‚   %9  = Core.getfield(%8, 1)::Pair{Int64, AbstractF}
β”‚   %10 = Base.indexed_iterate(%9, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
β”‚         (i = Core.getfield(%10, 1))
β”‚         (@_6 = Core.getfield(%10, 2))
β”‚   %13 = Base.indexed_iterate(%9, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{AbstractF, Int64}, Any[AbstractF, Core.Const(3)])
β”‚         (f = Core.getfield(%13, 1))
β”‚   %15 = Core.getfield(%8, 2)::Int64
β”‚   %16 = Main.rand(4)::Vector{Float64}
β”‚   %17 = Base.indexed_iterate(%16, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
β”‚         (a = Core.getfield(%17, 1))
β”‚         (@_5 = Core.getfield(%17, 2))
β”‚   %20 = Base.indexed_iterate(%16, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(3)])
β”‚         (b = Core.getfield(%20, 1))
β”‚         (@_5 = Core.getfield(%20, 2))
β”‚   %23 = Base.indexed_iterate(%16, 3, @_5::Core.Const(3))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(4)])
β”‚         (c = Core.getfield(%23, 1))
β”‚         (@_5 = Core.getfield(%23, 2))
β”‚   %26 = Base.indexed_iterate(%16, 4, @_5::Core.Const(4))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(5)])
β”‚         (d = Core.getfield(%26, 1))
β”‚   %28 = (f)(a, b, c, d)::Any
β”‚         Base.setindex!(nums, %28, i)
β”‚         (@_3 = Base.iterate(%3, %15))
β”‚   %31 = (@_3 === nothing)::Bool
β”‚   %32 = Base.not_int(%31)::Bool
└──       goto #4 if not %32
3 ─       goto #2
4 β”„       return nums

Namely, I now have f::AbstractF and %28 = (f)(a, b, c, d)::Any. Is there any way to get around this with the current setup?

Use FunctionWrappersWrappers.jl.

The answer is to recurse a tuple instead to keep the type information.

1 Like

Thanks @ChrisRackauckas. I’ve heard of this package before but couldn’t really make sense of it. Would you be able to explain how to apply FunctionWrappersWrappers.jl here please? I’m having a hard time looking through the code to understand how to actually use it / what it’s doing.

Your vector of FunctionWrappers concept in a type-stable manner is FunctionWrappersWrappers.jl’s FunctionWrappersWrapper

1 Like

I’ve had FunctionWrappersWrappers.jl working now, but this still leads to instabilities:

using FunctionWrappersWrappers
function wrap_functions(functions, parameters::P; float_type::Type{T}=Float64, u_type::Type{U}=Float64) where {T,U,P}
    wrapped_functions = ntuple(i -> FunctionWrappersWrapper(functions[i], (Tuple{T,T,T,U,typeof(parameters[i])},), (U,)), length(parameters))
    return wrapped_functions
end
f1 = (x, y, t, u, p) -> x*y*t
f2 = (x, y, t, u, p) -> x*y*t + p[1]
f3 = (x, y, t, u, p) -> x*y*t + p
params = (nothing, [1.0], 0.5)
wrapped = wrap_functions((f1,f2,f3), params)
function vector_of_function_test(wrappers, parameters)
    nums = zeros(length(wrappers))
    for i in eachindex(wrappers)
        x, y, t, u = rand(4)
        nums[i] = wrappers[i](x, y, t, u, parameters[i])
    end
    return nums
end
@code_warntype vector_of_function_test(wrapped, params)
MethodInstance for vector_of_function_test(::Tuple{FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, Tuple{Float64, Float64, Float64, Float64, Nothing}}}, false}, FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, Tuple{Float64, Float64, Float64, Float64, Vector{Float64}}}}, false}, FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, NTuple{5, Float64}}}, false}}, ::Tuple{Nothing, Vector{Float64}, Float64})
  from vector_of_function_test(wrappers, parameters) in Main at c:\Users\licer\using FunctionWrappers.jl:196
Arguments
  #self#::Core.Const(vector_of_function_test)
  wrappers::Tuple{FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, Tuple{Float64, Float64, Float64, Float64, Nothing}}}, false}, FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, Tuple{Float64, Float64, Float64, Float64, Vector{Float64}}}}, false}, FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, NTuple{5, Float64}}}, false}}       
  parameters::Tuple{Nothing, Vector{Float64}, Float64}
Locals
  @_4::Union{Nothing, Tuple{Int64, Int64}}
  nums::Vector{Float64}
  @_6::Int64
  i::Int64
  u::Float64
  t::Float64
  y::Float64
  x::Float64
Body::Vector{Float64}
1 ─ %1  = Main.length(wrappers)::Core.Const(3)
β”‚         (nums = Main.zeros(%1))
β”‚   %3  = Main.eachindex(wrappers)::Core.Const(Base.OneTo(3))
β”‚         (@_4 = Base.iterate(%3))
β”‚   %5  = (@_4::Core.Const((1, 1)) === nothing)::Core.Const(false)
β”‚   %6  = Base.not_int(%5)::Core.Const(true)
└──       goto #4 if not %6
2 β”„ %8  = @_4::Tuple{Int64, Int64}
β”‚         (i = Core.getfield(%8, 1))
β”‚   %10 = Core.getfield(%8, 2)::Int64
β”‚   %11 = Main.rand(4)::Vector{Float64}
β”‚   %12 = Base.indexed_iterate(%11, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
β”‚         (x = Core.getfield(%12, 1))
β”‚         (@_6 = Core.getfield(%12, 2))
β”‚   %15 = Base.indexed_iterate(%11, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(3)])
β”‚         (y = Core.getfield(%15, 1))
β”‚         (@_6 = Core.getfield(%15, 2))
β”‚   %18 = Base.indexed_iterate(%11, 3, @_6::Core.Const(3))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(4)])
β”‚         (t = Core.getfield(%18, 1))
β”‚         (@_6 = Core.getfield(%18, 2))
β”‚   %21 = Base.indexed_iterate(%11, 4, @_6::Core.Const(4))::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(5)])
β”‚         (u = Core.getfield(%21, 1))
β”‚   %23 = Base.getindex(wrappers, i)::Union{FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, Tuple{Float64, Float64, Float64, Float64, Nothing}}}, false}, FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, NTuple{5, Float64}}}, false}, FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Float64, Tuple{Float64, Float64, Float64, Float64, Vector{Float64}}}}, false}}
β”‚   %24 = x::Float64
β”‚   %25 = y::Float64
β”‚   %26 = t::Float64
β”‚   %27 = u::Float64
β”‚   %28 = Base.getindex(parameters, i)::Union{Nothing, Float64, Vector{Float64}}
β”‚   %29 = (%23)(%24, %25, %26, %27, %28)::Any
β”‚         Base.setindex!(nums, %29, i)
β”‚         (@_4 = Base.iterate(%3, %10))
β”‚   %32 = (@_4 === nothing)::Bool
β”‚   %33 = Base.not_int(%32)::Bool
└──       goto #4 if not %33
3 ─       goto #2
4 β”„       return nums

(from %23 and %29). Is there any way around this? This is removed if all the parameters in params are of the same type, but this won’t always be the case. Note also that my test function has the functions being evaluated in order, but it could be in arbitrary orders later, e.g. f[1], f[2], f[3], f[2], f[3], f[1], ....

I will not be offering a solution here (others have already done this) but rather analyse the problem (teaching Julia, with apologies if you do not need this :grin:)

In your first example, you pass a tuple containing two variables of the same subtype of FunctionWrapper. A loop over this tuple is typestable (no variable inside the loop change type from iteration to iteration). In particular, thanks to FunctionWrappers, the compiler can determine the type of the functions output - even though you deal with several functions (FunctionWrappers.jl is brilliant).

You break that in your second example where you pass a β€œheterogenous” tuple. Then it’s about the level of smartness we can expect of the compiler. As humans, we see that F and G wrappers specify that all function in your example have return values of the same type - which can lead to expect type stability.

But the real question is whether the computer has the smarts to prove type stability. I reckon the heuristic it uses goes like this:

if looping over container with uniform concrete type
     generate code handling same type of all variables at each iteration (typestable)
else
     generate code assuming new type of all variables at each iteration (type instability)
end
2 Likes

You can’t loop and index them like that You’d have to

    map(wrappers,parameters) do f,p
        x, y, t, u = rand(4)
        f(x, y, t, u, p)
    end
1 Like

I’d never used map before, that looks good. Doesn’t seem to be any type instability in @code_warntype, I would’ve thought there might still be in parameters[i].

Do you know if this idea can extend if I need to obtain the function from the loop variable? e.g.

function vector_map_to_functions(wrappers, parameters, nodes, node_map)
    nums = zeros(length(node_map))
    for j in nodes 
        i = node_map[j]
        x, y, t, u = rand(4)
        f = wrappers[i]
        nums[j] = f(x, y, t, u, parameters[i])
    end
    return nums 
end
f1 = (x, y, t, u, p) -> x*y*t
f2 = (x, y, t, u, p) -> x*y*t + p[1]
f3 = (x, y, t, u, p) -> x*y*t + p
params = (nothing, [1.0], 0.5)
wrapped = wrap_functions((f1,f2,f3), params)
nodes = rand(1:100, 50)
node_map = rand(1:3, 50)
@code_warntype vector_map_to_functions(wrapped,params,nodes,node_map)