I am trying to develop a package that needs to operate on small matrices whose size isn't known until runtime. Nevertheless, a lot of performance can be gained by specializing particular functions on the size of these matrices and using StaticArrays.
One way to do this is simply to use `Val(n)` and runtime dispatch. However, the dispatch itself incurs significant overhead. I can reduce this overhead substantially (possibly 10x) by using a binary search over the possible values, which turns the runtime dispatch into a compile-time one. But this requires the range of values to be hard-coded, which isn't ideal; it would be nicer if it were general. So I also tried writing the binary search recursively. Unfortunately, this wiped out most of the performance gains of the hard-coded version.
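A minimal sketch of how the hard-coded search could be made general without any recursion in the compiled code: a `@generated` function builds the same nested branches from the range bounds carried in the `Val` type parameters, so each leaf calls `fun` with a literal `Val(k)`. The names `valuedispatch_gen` and `_bisect_expr` are hypothetical, and values outside the range fall through to the nearest end of the range, as in the recursive `valuedispatch`.

```julia
# Hypothetical helper: build a nested `if` expression that binary-searches `val`
# over lo:hi and calls `fun(Val(k))` with a literal k at each leaf.
function _bisect_expr(lo::Int, hi::Int)
    lo == hi && return :(fun(Val($lo)))
    mid = lo + div(hi - lo, 2)
    return :(val <= $mid ? $(_bisect_expr(lo, mid)) : $(_bisect_expr(mid + 1, hi)))
end

# The generator runs once per (lo, hi) pair and returns the fully unrolled search.
# `fun` and `val` in the generated expression refer to this method's arguments.
@generated function valuedispatch_gen(::Val{lo}, ::Val{hi}, fun, val) where {lo, hi}
    return _bisect_expr(lo, hi)
end

valuedispatch_gen(Val(1), Val(32), typeof, 7)  # → Val{7}
```

Generating the branches keeps the flat structure of `valuedispatch1to32` while letting the range be chosen by the caller.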
This raises some questions for me:
- Why is runtime dispatch so slow, when a simple binary search is relatively fast?
- Why does the recursive implementation, which should essentially be doing the same thing as the hard-coded one, run much slower and incur allocations?
- Given that the binary search approach seems useful, is this functionality already available anywhere, e.g. in a package?
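One related tool that already ships with Julia is `Base.Cartesian.@nif`, which unrolls an `if`/`elseif` chain at macro-expansion time. A minimal sketch for the fixed range 1:32 (the name `valuedispatch_nif` is hypothetical) is below. Note that it produces a linear chain of equality tests rather than a binary search, and whether the compiler turns that chain into a jump table is not guaranteed, so it would need benchmarking against `valuedispatch1to32`.

```julia
# Linear if/elseif chain over 1:32, unrolled by Base.Cartesian.@nif.
# Any val outside 1:31 falls through to the final branch, fun(Val(32)).
function valuedispatch_nif(fun, val::Integer)
    Base.Cartesian.@nif 32 d->(val == d) d->fun(Val(d))
end

valuedispatch_nif(typeof, 7)  # → Val{7}
```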
Any thoughts?
MWE and benchmark results below:
```julia
using BenchmarkTools, StaticArrays

# Hard-coded binary search over 1:32. Every leaf calls `fun` with a literal Val(k),
# so each call site can be dispatched statically.
function valuedispatch1to32(fun, val)
    if val <= 16
        if val <= 8
            if val <= 4
                if val <= 2
                    if val == 2
                        return fun(Val(2))
                    else
                        return fun(Val(1))
                    end
                else
                    if val == 4
                        return fun(Val(4))
                    else
                        return fun(Val(3))
                    end
                end
            else
                if val <= 6
                    if val == 6
                        return fun(Val(6))
                    else
                        return fun(Val(5))
                    end
                else
                    if val == 8
                        return fun(Val(8))
                    else
                        return fun(Val(7))
                    end
                end
            end
        else
            if val <= 12
                if val <= 10
                    if val == 10
                        return fun(Val(10))
                    else
                        return fun(Val(9))
                    end
                else
                    if val == 12
                        return fun(Val(12))
                    else
                        return fun(Val(11))
                    end
                end
            else
                if val <= 14
                    if val == 14
                        return fun(Val(14))
                    else
                        return fun(Val(13))
                    end
                else
                    if val == 16
                        return fun(Val(16))
                    else
                        return fun(Val(15))
                    end
                end
            end
        end
    else
        if val <= 24
            if val <= 20
                if val <= 18
                    if val == 18
                        return fun(Val(18))
                    else
                        return fun(Val(17))
                    end
                else
                    if val == 20
                        return fun(Val(20))
                    else
                        return fun(Val(19))
                    end
                end
            else
                if val <= 22
                    if val == 22
                        return fun(Val(22))
                    else
                        return fun(Val(21))
                    end
                else
                    if val == 24
                        return fun(Val(24))
                    else
                        return fun(Val(23))
                    end
                end
            end
        else
            if val <= 28
                if val <= 26
                    if val == 26
                        return fun(Val(26))
                    else
                        return fun(Val(25))
                    end
                else
                    if val == 28
                        return fun(Val(28))
                    else
                        return fun(Val(27))
                    end
                end
            else
                if val <= 30
                    if val == 30
                        return fun(Val(30))
                    else
                        return fun(Val(29))
                    end
                else
                    if val == 32
                        return fun(Val(32))
                    else
                        return fun(Val(31))
                    end
                end
            end
        end
    end
end
```
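A hand-written tree like this is easy to get subtly wrong, and for a range this small an exhaustive check is cheap. A minimal sketch with hypothetical helper names (`valparam`, `check_dispatch`, `reference_dispatch`, `broken_dispatch`), assuming any dispatcher takes `(fun, val)` in that order as `valuedispatch1to32` does:

```julia
# Return the integer carried by a Val type parameter.
valparam(::Val{N}) where {N} = N

# True only if `dispatch(fun, n)` ends up calling `fun(Val(n))` for every n in `range`.
check_dispatch(dispatch, range) = all(dispatch(valparam, n) == n for n in range)

# A reference dispatcher (plain runtime dispatch) and a deliberately broken one.
reference_dispatch(fun, n) = fun(Val(n))
broken_dispatch(fun, n) = n == 9 ? fun(Val(11)) : fun(Val(n))

check_dispatch(reference_dispatch, 1:32)  # → true
check_dispatch(broken_dispatch, 1:32)     # → false
```

With the code above loaded, `check_dispatch(valuedispatch1to32, 1:32)` returns `true` only if every leaf maps `n` to `Val(n)`.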
```julia
# Recursive binary search over lower:upper; the bounds are type parameters.
function valuedispatch(::Val{lower}, ::Val{upper}, fun, val) where {lower, upper}
    if lower >= upper
        return fun(Val(upper))
    end
    midpoint::Int = lower + div(upper - lower, 2)
    if val <= midpoint
        return valuedispatch(Val(lower), Val(midpoint), fun, val)
    else
        return valuedispatch(Val(midpoint + 1), Val(upper), fun, val)
    end
end
```
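One variant that may be worth benchmarking moves the base case into its own method, so `lower == upper` is resolved by dispatch rather than by a branch inside the recursive method. This is a sketch under assumptions: the name `valuedispatch_rec` is hypothetical, it requires `lo <= hi` on the first call, and whether the compiler infers the whole recursion (and so avoids the allocations) depends on its recursion heuristics, which this sketch does not guarantee. `@code_warntype` or `@btime` would show whether it actually helps.

```julia
# Base case, selected by dispatch: both bounds are the same value.
valuedispatch_rec(::Val{lo}, ::Val{lo}, fun, val) where {lo} = fun(Val(lo))

# Recursive case: halve lo:hi (assumes lo <= hi). `mid` depends only on the
# type parameters, so Val(mid) and Val(mid + 1) can be constant-folded.
function valuedispatch_rec(::Val{lo}, ::Val{hi}, fun, val) where {lo, hi}
    mid = lo + div(hi - lo, 2)
    if val <= mid
        return valuedispatch_rec(Val(lo), Val(mid), fun, val)
    else
        return valuedispatch_rec(Val(mid + 1), Val(hi), fun, val)
    end
end

valuedispatch_rec(Val(1), Val(32), typeof, 7)  # → Val{7}
```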
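```julia
# Example kernel, specialized on the vector length N.
function myfunc(::Val{N}) where N
    x = randn(SVector{N, Float64})
    return x' * x
end

N = rand(1:32, 10000)  # 10,000 random sizes in 1:32
@btime foreach(n -> myfunc(Val(n)), $N)                             # runtime dispatch
@btime foreach(n -> valuedispatch1to32(myfunc, n), $N)              # hard-coded binary search
@btime foreach(n -> valuedispatch(Val(1), Val(32), myfunc, n), $N)  # recursive binary search
```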
```
Runtime dispatch, myfunc(Val(n)):  3.364 ms (10000 allocations: 156.25 KiB)
Hard-coded binary search:          895.042 μs (0 allocations: 0 bytes)
Recursive binary search:           2.215 ms (10000 allocations: 156.25 KiB)
```
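For scale, dividing each result by the 10,000 calls gives rough per-call figures. These include the `randn` and dot-product work inside `myfunc`, not just dispatch. The 16 bytes per call is consistent with one small heap object per call (for example a boxed `Float64` return value), although the benchmark alone doesn't show what is allocated.

$$
\frac{3.364\ \mathrm{ms}}{10^4} \approx 336\ \mathrm{ns},\qquad
\frac{895.042\ \mu\mathrm{s}}{10^4} \approx 89.5\ \mathrm{ns},\qquad
\frac{2.215\ \mathrm{ms}}{10^4} \approx 222\ \mathrm{ns}
$$

$$
\frac{156.25\ \mathrm{KiB}}{10^4} = \frac{160\,000\ \mathrm{B}}{10^4} = 16\ \mathrm{B\ per\ call}
$$

End to end, the hard-coded search is therefore about 3.8x faster than plain runtime dispatch (3.364 / 0.895 ≈ 3.76).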