Experiments in speeding up runtime dispatch for integer value specialization

I am trying to develop a package that needs to operate on small matrices, the size of which isn’t known until runtime. Nevertheless, there is a lot of performance to be gained by specializing/parameterizing particular functions on the size of these matrices, and using StaticArrays.

One way to do this is simply to use Val(n) and runtime dispatch. However, this in itself incurs a significant overhead. I can reduce this overhead significantly (possibly 10x) by using a binary search to turn the runtime dispatch into a compile time one. However, this requires the value range to be hardcoded, which isn’t ideal - it would be nice if it were general. So I also tried to write the binary search algorithm recursively. Unfortunately this wiped out most of the performance gains of the hard-coded binary search.

This raises some questions for me:

  1. Why is runtime dispatch so slow, when a simple binary search is relatively fast?
  2. Why does the recursive implementation, which should essentially be doing the same thing as the hard-coded implementation, run much slower, and incur some allocations?
  3. Given that the binary search approach seems useful, is this functionality already available anywhere currently?

Any thoughts?

MWE and benchmark results below:

using BenchmarkTools, StaticArrays

function valuedispatch1to32(fun, val)
    if val <= 16
        if val <= 8
            if val <= 4
                if val <= 2
                    if val == 2
                        return fun(Val(2))
                    else
                        return fun(Val(1))
                    end
                else
                    if val == 4
                        return fun(Val(4))
                    else
                        return fun(Val(3))
                    end
                end
            else
                if val <= 6
                    if val == 6
                        return fun(Val(6))
                    else
                        return fun(Val(5))
                    end
                else
                    if val == 8
                        return fun(Val(8))
                    else
                        return fun(Val(7))
                    end
                end
            end
        else
            if val <= 12
                if val <= 10
                    if val == 10
                        return fun(Val(10))
                    else
                        return fun(Val(9))
                    end
                else
                    if val == 12
                        return fun(Val(12))
                    else
                        return fun(Val(11))
                    end
                end
            else
                if val <= 14
                    if val == 14
                        return fun(Val(14))
                    else
                        return fun(Val(13))
                    end
                else
                    if val == 16
                        return fun(Val(16))
                    else
                        return fun(Val(15))
                    end
                end
            end
        end
    else
        if val <= 24
            if val <= 20
                if val <= 18
                    if val == 18
                        return fun(Val(18))
                    else
                        return fun(Val(17))
                    end
                else
                    if val == 20
                        return fun(Val(20))
                    else
                        return fun(Val(19))
                    end
                end
            else
                if val <= 22
                    if val == 22
                        return fun(Val(22))
                    else
                        return fun(Val(21))
                    end
                else
                    if val == 24
                        return fun(Val(24))
                    else
                        return fun(Val(23))
                    end
                end
            end
        else
            if val <= 28
                if val <= 26
                    if val == 26
                        return fun(Val(26))
                    else
                        return fun(Val(25))
                    end
                else
                    if val == 28
                        return fun(Val(28))
                    else
                        return fun(Val(27))
                    end
                end
            else
                if val <= 30
                    if val == 30
                        return fun(Val(30))
                    else
                        return fun(Val(29))
                    end
                else
                    if val == 32
                        return fun(Val(32))
                    else
                        return fun(Val(31))
                    end
                end
            end
        end
    end
end

function valuedispatch(::Val{lower}, ::Val{upper}, fun, val) where {lower, upper}
    if lower >= upper
        return fun(Val(upper))
    end
    midpoint::Int = lower + div(upper - lower, 2)
    if val <= midpoint
        return valuedispatch(Val(lower), Val(midpoint), fun, val)
    else
        return valuedispatch(Val(midpoint+1), Val(upper), fun, val)
    end
end

function myfunc(::Val{N}) where N
    x = randn(SVector{N, Float64})
    return x' * x
end

N = rand(1:32, 10000)
@btime foreach(n -> myfunc(Val(n)), $N)
@btime foreach(n -> valuedispatch1to32(myfunc, n), $N)
@btime foreach(n -> valuedispatch(Val(1), Val(32), myfunc, n), $N)

  3.364 ms (10000 allocations: 156.25 KiB)
  895.042 μs (0 allocations: 0 bytes)
  2.215 ms (10000 allocations: 156.25 KiB)

Could a recursive macro be used instead of a recursive function, to make the recursive version as fast as the hand-coded version?

Maybe of interest

Ah, nice. Thanks! Unfortunately the switch statement approach ValSplit uses won’t be as fast. It has O(N) complexity rather than O(log2(N)) complexity of the binary search, N being the number of possible values. It is more general, though. I’m assuming the set of values is a contiguous set of integers.
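
For comparison, a linear chain of equality tests can be sketched with Base.Cartesian.@nif. This only illustrates the O(N) shape being discussed and is not ValSplit's actual implementation; linear_dispatch and getval are hypothetical names, and the compiler may well optimise such a chain further than the source suggests.

```julia
using Base.Cartesian: @nif

# Hypothetical stand-in for `fun`: returns the integer carried by the Val type.
getval(::Val{N}) where {N} = N

# Expands to: if n == 1 fun(Val(1)) elseif n == 2 ... else fun(Val(8)) end
# Every branch calls `fun` with a compile-time constant Val, but up to 7
# comparisons are needed, versus 3 for a binary search over 1:8.
# Any n outside 1:7 falls through to the final Val(8) branch.
linear_dispatch(fun, n) = @nif 8 i->(n == i) i->fun(Val(i))

linear_dispatch(getval, 5)
```

As with the binary search, the range (here 1:8) has to be a literal known when the macro is expanded.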

How much and what kind of work is fun doing?

EDIT: Depending on what you’re doing, dynamically sized arrays may well be the better choice, with the real problem being something like BLAS overhead or bad code generation.

Here is a slightly convoluted solution, using a generated function which uses a macro which calls a function to construct the expression:

function valuedispatch_expr(
        ::Val{lower}, ::Val{upper}, fun, val,
    ) where {lower, upper}
    if lower >= upper
        return :( $fun(Val($upper)) )
    end
    midpoint = lower + div(upper - lower, 2)
    expr_a = valuedispatch_expr(Val(lower), Val(midpoint), fun, val)
    expr_b = valuedispatch_expr(Val(midpoint+1), Val(upper), fun, val)
    quote
        if $val <= $midpoint
            $expr_a
        else
            $expr_b
        end
    end
end

macro valuedispatch(lower::Int, upper::Int, fun, val)
    valuedispatch_expr(Val(lower), Val(upper), esc(fun), esc(val))
end

@generated function valuedispatch_gen(
        ::Val{lower}, ::Val{upper}, fun, val,
    ) where {lower, upper}
    :( @valuedispatch($lower, $upper, fun, val) )
end

## Timings
N = rand(1:32, 10000)
@btime foreach(n -> myfunc(Val(n)), $N)
@btime foreach(n -> valuedispatch1to32(myfunc, n), $N)
@btime foreach(n -> valuedispatch(Val(1), Val(32), myfunc, n), $N)
@btime foreach(n -> valuedispatch_gen(Val(1), Val(32), myfunc, n), $N)

  5.518 ms (10000 allocations: 156.25 KiB)
  1.295 ms (0 allocations: 0 bytes)
  1.844 ms (10000 allocations: 156.25 KiB)
  1.282 ms (0 allocations: 0 bytes)

One such function creates a small vector, and computes the norm. This is done millions of times. If you don’t know the size, the vector is allocated on the heap every time. That’s the performance issue.

Can you provide the actual code of that? With such a description one immediately thinks: do you need to instantiate the vector at all?

You mean on the heap every time?
You could try not allocating it on the heap every time.

Fantastic! Thank you. I assume that if you generate this function in two places, with the same bounds but a different function to be called, then it reuses the generated code?

This is precisely why I don’t like to go into details. :laughing:
I’m experienced enough to know the design is the way it needs to be. And it will take me too long to explain the whole thing so you reach the same conclusion. So I’d prefer not to.

Sorry, just remembered some other things that make it yet more complicated. The functions need to be autodiffed with ForwardDiff, with a variable number of variables. That really needs compile time dispatch to be efficient. Trust me, it’s not so simple.

I do feel this is moving away from the topic I was interested in, which was about runtime dispatch speeds being slow, and how to fix that.

The xy problem is common when people ask questions.
I was worried about your code taking an eternity to compile, being unreasonably bloated, and achieving worse runtime performance than would otherwise be possible.

However, we shouldn’t let perfect be the enemy of the good, and if a perfect solution takes too much effort, it’s unlikely to be worth it.

Undoubtedly. I don’t think that’s the problem here, though. Quite the opposite. However, I regret to say I don’t know the label for my problem. But I’m simply trying to save both you and myself time. In any case, I asked a question I’m genuinely interested in.

I see. I haven’t noticed bloat or slow compilation. My main focus is writing a package that is easy to use, and that achieves excellent runtime performance.

Very true. For example, I haven’t got time to implement an autodiff framework that would play nicely with the particular approach you have in mind. And I also don’t think that would be a good use of my time.

I’ve never used Julia macros before, so I find this approach very interesting. What is the difference between what this does and what my code did? Does this generate an unrolled function? Or does it simply somehow avoid the allocations that were happening in my code?

Actually, I think I understand why the allocations are happening. The compiler cannot tell what type the output of the recursive function will be. I need to be able to provide a hint to the compiler that the output type of valuedispatch is the same as the output type of the input argument function handle, fun. Is there a way to do that?
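
One way to give such a hint, sketched here under assumptions, is a type assertion on the result of the call. It tells the callers what type to expect, but it may not remove the dynamic dispatch or the allocation inside the recursive function itself. sumsq is a hypothetical stand-in for myfunc that avoids the StaticArrays dependency, and checked is a hypothetical wrapper name.

```julia
# Recursive dispatcher, as in the original post.
function valuedispatch(::Val{lower}, ::Val{upper}, fun, val) where {lower, upper}
    if lower >= upper
        return fun(Val(upper))
    end
    midpoint::Int = lower + div(upper - lower, 2)
    if val <= midpoint
        return valuedispatch(Val(lower), Val(midpoint), fun, val)
    else
        return valuedispatch(Val(midpoint + 1), Val(upper), fun, val)
    end
end

# Hypothetical stand-in for myfunc: sum of squares of 1.0, 2.0, ..., N.
sumsq(::Val{N}) where {N} = sum(abs2, ntuple(Float64, Val(N)))

# The `::Float64` assertion throws a TypeError if the result has any other
# type, and lets code that calls `checked` be inferred as receiving a Float64.
checked(n) = valuedispatch(Val(1), Val(32), sumsq, n)::Float64

checked(3)
```

The assertion requires knowing the output type in advance; here it is assumed that fun always returns a Float64.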

The macro is generating more or less exactly the same function valuedispatch1to32 that you created manually.

To understand what is going on, you can start by directly calling my valuedispatch_expr function, which returns an expression. Note that, in contrast to your approach (your valuedispatch function), this function doesn’t compute anything, it just builds some code that can be later evaluated:

julia> ex = valuedispatch_expr(Val(1), Val(4), :somefunction, :somevalue)
quote
    if somevalue <= 2
        begin
            if somevalue <= 1
                somefunction(Val(1))
            else
                somefunction(Val(2))
            end
        end
    else
        begin
            if somevalue <= 3
                somefunction(Val(3))
            else
                somefunction(Val(4))
            end
        end
    end
end

julia> typeof(ex)
Expr

(Here I’m using upper = 4 just to keep things shorter.)

The valuedispatch macro just splices the constructed expression into the code at the call site. One can use @macroexpand to see what’s going on when one calls the macro:

julia> n = 3
3

julia> @valuedispatch(1, 4, myfunc, n)
0.8017024060007463

julia> @macroexpand @valuedispatch(1, 4, myfunc, n)
quote
    if n <= 2
        begin
            if n <= 1
                myfunc(Main.Val(1))
            else
                myfunc(Main.Val(2))
            end
        end
    else
        begin
            if n <= 3
                myfunc(Main.Val(3))
            else
                myfunc(Main.Val(4))
            end
        end
    end
end

Finally, the @generated function valuedispatch_gen is there to make sure that the @valuedispatch macro receives the actual integer values of lower and upper instead of some Julia symbols. Note that the following alternative doesn’t work, because macros are expanded before lower and upper have values, so the macro only sees the symbols:

julia> f(::Val{lower}, ::Val{upper}, fun, val) where {lower, upper} = @valuedispatch(lower, upper, fun, val)
ERROR: LoadError: MethodError: no method matching var"@valuedispatch"(::LineNumberNode, ::Module, ::Symbol, ::Symbol, ::Symbol, ::Symbol)
Closest candidates are:
  var"@valuedispatch"(::LineNumberNode, ::Module, ::Int64, ::Int64, ::Any, ::Any) at ~/tmp/julia/valdispatch/valdispatch.jl:191
in expression starting at REPL[26]:1
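
One possible simplification, offered as a sketch rather than as part of the solution above: since a generated function already sees lower and upper as values, it can return the expression tree directly, without the intermediate macro. dispatch_tree, valuedispatch_direct and getval are hypothetical names.

```julia
# Build the nested if/else expression for the integer range lower:upper.
# The symbols `fun` and `val` refer to the arguments of the generated
# function below, so this helper is only meant to be used from there.
function dispatch_tree(lower::Int, upper::Int)
    lower >= upper && return :(fun(Val($upper)))
    mid = lower + (upper - lower) ÷ 2
    return quote
        if val <= $mid
            $(dispatch_tree(lower, mid))
        else
            $(dispatch_tree(mid + 1, upper))
        end
    end
end

# Inside the generator, `lower` and `upper` are plain integers taken from
# the Val type parameters, so the tree can be built without a macro.
@generated function valuedispatch_direct(
        ::Val{lower}, ::Val{upper}, fun, val,
    ) where {lower, upper}
    return dispatch_tree(lower, upper)
end

# Hypothetical stand-in for `fun`: returns the integer carried by the Val type.
getval(::Val{N}) where {N} = N

valuedispatch_direct(Val(1), Val(8), getval, 5)
```

Like the hand-written version, this performs no bounds check: a val outside lower:upper silently lands in the first or last branch.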

Great! Thank you so much.

My lingering impression from this result is that the system Julia currently uses to perform dynamic dispatch could be much faster.

There is a discussion about that here: Union splitting vs C++ - #13 by ChenNingCong

(see some previous posts).

I don’t know if there are specific proposals on how that can be improved, but there seem to be inherent limitations relative to the approach taken in static languages.

The answer to the first question at the top is “hidden type instability”.
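
The instability can be made visible with a small sketch. It assumes Base.return_types, an unexported introspection helper, and the names runtime_val and constant_val are hypothetical.

```julia
# With a runtime integer, the type parameter of the result is unknown to the
# compiler, so the inferred return type is not a concrete type.
runtime_val(n::Int) = Val(n)

# With a literal, constant propagation lets the compiler infer Val{3}.
constant_val() = Val(3)

rt = first(Base.return_types(runtime_val, (Int,)))
ct = first(Base.return_types(constant_val, ()))

isconcretetype(rt), isconcretetype(ct)
```

Any call that receives the result of runtime_val therefore has to be resolved by dynamic dispatch at runtime, whereas each branch of the binary search passes a literal Val.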
As mentioned in @lmiq’s reference this can also be resolved with vtables, so here’s a related, alternative approach to the problem at hand, using type trickery instead of metaprogramming:

using FunctionWrappers
struct FT{N}
end
function (ft::FT{N})() where {N}
    x = randn(SVector{N, Float64})
    return x' * x
end

const FW = FunctionWrapper{Float64,Tuple{}}
const tbl = FW[FW(FT{n}()) for n in 1:32]
function table_dispatch(n)
    f = tbl[n]
    f()
end