I am trying to develop a package that needs to operate on small matrices whose size isn’t known until runtime. Nevertheless, there is a lot of performance to be gained by specializing/parameterizing particular functions on the size of these matrices and using StaticArrays.
One way to do this is simply to call Val(n) and rely on runtime dispatch. However, this in itself incurs significant overhead. I can reduce this overhead considerably (possibly 10x) by using a binary search over the possible values, so that every leaf calls the function with a literal Val and the runtime dispatch becomes a compile-time one. However, this requires the value range to be hard-coded, which isn’t ideal; it would be nice if it were general. So I also tried to write the binary search recursively. Unfortunately, this wiped out most of the performance gains of the hard-coded binary search.
This raises some questions for me:
- Why is runtime dispatch so slow, when a simple binary search is relatively fast?
- Why does the recursive implementation, which should essentially be doing the same thing as the hard-coded implementation, run much slower, and incur some allocations?
- Given that the binary search approach seems useful, is this functionality already available anywhere currently?
Any thoughts?
MWE and benchmark results below:
using BenchmarkTools, StaticArrays
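# Hard-coded binary search over 1:32. Every leaf calls `fun` with a literal Val,
# so each call site can be resolved at compile time.
function valuedispatch1to32(fun, val)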
    if val <= 16
        if val <= 8
            if val <= 4
                if val <= 2
                    if val == 2
                        return fun(Val(2))
                    else
                        return fun(Val(1))
                    end
                else
                    if val == 4
                        return fun(Val(4))
                    else
                        return fun(Val(3))
                    end
                end
            else
                if val <= 6
                    if val == 6
                        return fun(Val(6))
                    else
                        return fun(Val(5))
                    end
                else
                    if val == 8
                        return fun(Val(8))
                    else
                        return fun(Val(7))
                    end
                end
            end
        else
            if val <= 12
                if val <= 10
                    if val == 10
                        return fun(Val(10))
                    else
                        return fun(Val(9))
                    end
                else
                    if val == 12
                        return fun(Val(12))
                    else
                        return fun(Val(11))
                    end
                end
            else
                if val <= 14
                    if val == 14
                        return fun(Val(14))
                    else
                        return fun(Val(13))
                    end
                else
                    if val == 16
                        return fun(Val(16))
                    else
                        return fun(Val(15))
                    end
                end
            end
        end
    else
        if val <= 24
            if val <= 20
                if val <= 18
                    if val == 18
                        return fun(Val(18))
                    else
                        return fun(Val(17))
                    end
                else
                    if val == 20
                        return fun(Val(20))
                    else
                        return fun(Val(19))
                    end
                end
            else
                if val <= 22
                    if val == 22
                        return fun(Val(22))
                    else
                        return fun(Val(21))
                    end
                else
                    if val == 24
                        return fun(Val(24))
                    else
                        return fun(Val(23))
                    end
                end
            end
        else
            if val <= 28
                if val <= 26
                    if val == 26
                        return fun(Val(26))
                    else
                        return fun(Val(25))
                    end
                else
                    if val == 28
                        return fun(Val(28))
                    else
                        return fun(Val(27))
                    end
                end
            else
                if val <= 30
                    if val == 30
                        return fun(Val(30))
                    else
                        return fun(Val(29))
                    end
                else
                    if val == 32
                        return fun(Val(32))
                    else
                        return fun(Val(31))
                    end
                end
            end
        end
    end
end
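# Recursive binary search over lower:upper, with the bounds carried in the type domain.
function valuedispatch(::Val{lower}, ::Val{upper}, fun, val) where {lower, upper}
    # Base case: the range has collapsed to a single value.
    if lower >= upper
        return fun(Val(upper))
    end
    # Split the range in half and recurse into the half that contains val.
    midpoint::Int = lower + div(upper - lower, 2)
    if val <= midpoint
        return valuedispatch(Val(lower), Val(midpoint), fun, val)
    else
        return valuedispatch(Val(midpoint+1), Val(upper), fun, val)
    end
end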
function myfunc(::Val{N}) where N
    x = randn(SVector{N, Float64})
    return x' * x
end
N = rand(1:32, 10000)
# 1. Plain runtime dispatch on Val(n)
@btime foreach(n -> myfunc(Val(n)), $N)
# 2. Hard-coded binary search
@btime foreach(n -> valuedispatch1to32(myfunc, n), $N)
# 3. Recursive binary search
@btime foreach(n -> valuedispatch(Val(1), Val(32), myfunc, n), $N)
Results, in the same order as the three benchmarks:
  3.364 ms (10000 allocations: 156.25 KiB)
  895.042 μs (0 allocations: 0 bytes)
  2.215 ms (10000 allocations: 156.25 KiB)
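
One possible way to get the general range without writing the tree by hand would be to generate the nested if/else at compile time with a `@generated` function. The sketch below is untested for performance, so I can't say whether it matches the hard-coded version. The names `dispatch_tree` and `valuedispatch_generated` are made up for illustration, and it assumes the bounds are `Int`s. Like the recursive version, it clamps out-of-range values to the nearest bound.

```julia
# Hypothetical helper: build the nested if/else expression for lower:upper.
# It runs at code-generation time, so the recursion happens only once
# per (lower, upper) pair.
function dispatch_tree(lower::Int, upper::Int)
    # Leaf: a single value remains, so emit a call with a literal Val.
    lower >= upper && return :(return fun(Val($upper)))
    midpoint = lower + div(upper - lower, 2)
    return quote
        if val <= $midpoint
            $(dispatch_tree(lower, midpoint))
        else
            $(dispatch_tree(midpoint + 1, upper))
        end
    end
end

# The returned expression becomes the method body, where `fun` and `val`
# refer to the runtime arguments.
@generated function valuedispatch_generated(::Val{lower}, ::Val{upper}, fun, val) where {lower, upper}
    return dispatch_tree(lower, upper)
end
```

It would be called the same way as the recursive version, e.g. `valuedispatch_generated(Val(1), Val(32), myfunc, n)`. The generated body should be structurally the same as the hand-written `valuedispatch1to32`, but that would need to be confirmed with `@btime`.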
