Call function on vectors of mixed type (using `FunctionWrapper` and `Union`s)

Hi everyone and a happy new year!

In one part of my code I have a vector v whose elements have different types.
Additionally, I have a function f with methods for all the involved types, and I want to call it on each element of the vector. Let’s say I want to compute sum(f(x) for x in v).
What is the best/fastest way to do this?

There are three things I have tried so far:

  1. Just leaving v as it is as a Vector{Any}
  2. Transforming v to a union typed vector of type Union{unique(typeof.(v))...}
  3. Using the FunctionWrappers.jl package (on either the normal or the union typed vector, which doesn’t seem to make a difference).

What I observe is that the union typed vector is the fastest for up to three different types, but its computation time jumps by a factor of 1000 at four different types (I guess this has something to do with the union splitting limit I read about elsewhere).
The time used by the FunctionWrapper implementation, on the other hand, is independent of the number of different types. It is a factor of 6 slower than the union typed implementation for 1–3 different types, which makes it by far the fastest option for four or more types.

So I guess in the end I have three questions:

  1. Most importantly, is there a way to further improve performance in such cases, especially for the case of many different types (an example code is appended below)?
  2. Since I like to understand the code I write, could someone maybe explain to me what FunctionWrapper actually does?
  3. I would also like to understand why I observe what I observe. Why is there this sharp increase in computation time for the Union type vector after more than three different types and why is my FunctionWrapper implementation for only one type slower than the “naive” implementation?

I am grateful for answers to any of these questions!

Here is the example code I used to benchmark this:

using FunctionWrappers: FunctionWrapper
using BenchmarkTools

for i in 1:100
    str = """
    struct X$i 
        x::Int
    end"""
    include_string(Main, str)
end

for i in 1:100
    y = rand()
    # Named method_src rather than `string` to avoid shadowing Base.string
    method_src = """
    g(x::X$i, z::Int) = (x.x + $y) * z
    """
    include_string(Main, method_src)
end

function make_random_vector(len, nTypes)
    vals = rand(1:100, len)
    types = rand(1:nTypes, len)
    # Interpolate only the i-th value; "$vals[$i]" would paste the whole vector into the string
    str(i) = "X$(types[i])($(vals[i]))"
    return [eval(Meta.parse(str(i))) for i in eachindex(vals)]
end


function benchmark_trial(len, maxNTypes)
    vecTimes = []
    unionTimes = []
    wrappedTimes = []
    wrappedUnionTimes = []
    for i in 1:maxNTypes
        vect = make_random_vector(len, i)
        unionType = Union{unique(typeof.(vect))...}
        unionVec::Vector{unionType} = Vector{unionType}(vect)
        wrapped =  [FunctionWrapper{Float64, Tuple{Int}}(y->g(x,y)) for x in vect]
        wrappedUnion =  [FunctionWrapper{Float64, Tuple{Int}}(y->g(x,y)) for x in unionVec]
        
        push!(vecTimes, mean(@benchmark sum(g(x,5) for x in $vect)))
        push!(unionTimes, mean(@benchmark sum(g(x,5) for x in $unionVec)))
        push!(wrappedTimes, mean(@benchmark sum(h(5) for h in $wrapped)))
        push!(wrappedUnionTimes, mean(@benchmark sum(h(5) for h in $wrappedUnion)))
    end
    return vecTimes, unionTimes, wrappedTimes, wrappedUnionTimes
end

benchmarkdata = benchmark_trial(100, 20)

using Plots
benchmarkdata[1][1].time
plot([[benchmarkdata[j][i].time for i in 1:20] for j in 1:4], 
    yscale=:log10, 
    label = ["Vector" "UnionVector" "Wrapped" "Wrapped Union"],
    xlabel = "Number of Types",
    ylabel = "time_ns"
    )

A FunctionWrapper is essentially a C function pointer with a fixed signature, I believe. Each element of wrapped stores the memory address of the function you want to execute, so when you call h(5) you first load that address and then jump to it. This indirection is why it is slow even with one type: a normal call to a known method can be resolved at compile time (and possibly inlined), so it doesn’t need to do this.
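To make this concrete, here is a minimal sketch of the wrapping step. The types A and B and the method bodies are hypothetical stand-ins for the X$i types and g from the question; only the FunctionWrapper{Float64, Tuple{Int}} constructor is taken from the original code. The point it illustrates is that every wrapper has the same concrete type no matter which element it captured, so the vector of wrappers is concretely typed.

```julia
using FunctionWrappers: FunctionWrapper

# Hypothetical stand-ins for the X1, X2, ... types in the question
struct A; x::Int; end
struct B; x::Int; end

g(a::A, z::Int) = (a.x + 0.5) * z
g(b::B, z::Int) = (b.x + 1.5) * z

items = Any[A(1), B(2)]

# Each closure captures one element; the wrapper hides the closure's type
# behind the fixed signature Int -> Float64.
wrapped = [FunctionWrapper{Float64, Tuple{Int}}(z -> g(x, z)) for x in items]

# All wrappers share one concrete type, independent of the captured element.
same_type = eltype(wrapped) == FunctionWrapper{Float64, Tuple{Int}}

# Calling a wrapper goes through the stored function pointer.
total = sum(h(5) for h in wrapped)
```

The price of the uniform element type is that every call is an indirect call through a pointer, which would explain the constant overhead seen in the benchmark.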

The Union vector is probably slow beyond three types because, for three or fewer types, g(x, 5) is compiled to something equivalent to

if x isa X1
    g(x, 5)
elseif x isa X2
    g(x, 5)
else
    g(x, 5)
end

This is fast because inside each branch the type of x is known, so the matching method of g can be selected at compile time and no run-time lookup is needed. For four or more types the compiler no longer generates these branches, and the run-time lookup (dynamic dispatch) does occur.
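The same branching can also be written by hand for more types than the compiler splits automatically. The following is only a sketch: the types P1 to P4, the function f, and the helper names f_split and total_split are made up for illustration, and whether it actually beats the alternatives should be checked with a benchmark on the real types.

```julia
# Hypothetical types and methods standing in for X1..X4 and g
struct P1; x::Int; end
struct P2; x::Int; end
struct P3; x::Int; end
struct P4; x::Int; end

f(p::P1, z::Int) = (p.x + 0.1) * z
f(p::P2, z::Int) = (p.x + 0.2) * z
f(p::P3, z::Int) = (p.x + 0.3) * z
f(p::P4, z::Int) = (p.x + 0.4) * z

# Explicit type checks: inside each `isa` branch the concrete type of x is
# known, so the call to f there can be resolved at compile time.
function f_split(x, z::Int)
    if x isa P1
        return f(x, z)
    elseif x isa P2
        return f(x, z)
    elseif x isa P3
        return f(x, z)
    elseif x isa P4
        return f(x, z)
    else
        return f(x, z)  # any other type falls back to ordinary dispatch
    end
end

function total_split(v, z::Int)
    s = 0.0
    for x in v
        s += f_split(x, z)
    end
    return s
end

v = Union{P1, P2, P3, P4}[P1(1), P2(2), P3(3), P4(4)]
result = total_split(v, 5)
```

With 100 types, writing these branches by hand is impractical, so this only helps when the set of types is small and known in advance.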

If you have a lot of types and cannot keep a separate vector for each type, I think the best solution would be to use FunctionWrappers.jl or Virtual.jl, which both use function pointers.

You could also check out TypeSortedCollections.jl. I have never used it, but it seems to be a good fit for your use case.
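For reference, the "separate vector for each type" idea can be hand-rolled in a few lines of plain Julia. This sketch does not use the TypeSortedCollections.jl API; the types Q1 and Q2, the function k, and the helpers group_by_type, sum_group and total_grouped are all hypothetical names. It assumes the order of elements does not matter, which is true for a sum.

```julia
# Hypothetical types and methods standing in for X1, X2, ... and g
struct Q1; x::Int; end
struct Q2; x::Int; end

k(q::Q1, z::Int) = (q.x + 0.25) * z
k(q::Q2, z::Int) = (q.x + 0.75) * z

# Sort the elements into one concretely typed vector per type.
function group_by_type(v)
    groups = Dict{DataType, Vector}()
    for x in v
        T = typeof(x)
        push!(get!(() -> T[], groups, T), x)
    end
    return groups
end

# Function barrier: this method is compiled once per concrete Vector{T},
# so the call to k inside it is statically dispatched.
sum_group(group::Vector, z::Int) = sum(x -> k(x, z), group; init = 0.0)

# One dynamic dispatch per group instead of one per element.
total_grouped(groups, z::Int) = sum(sum_group(grp, z) for grp in values(groups); init = 0.0)

mixed = Any[Q1(1), Q2(2), Q1(3)]
groups = group_by_type(mixed)
total = total_grouped(groups, 4)
```

The grouping itself costs time, so this pays off mainly when the same collection is traversed many times.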