To get the best performance across different inputs, we often have multiple implementations, and choose which one to use based on the input given. For example, for array traversals we may want to unroll loops for small arrays, but not for larger ones. So we’d have an implementation of each, and choose based on some cutoff.
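To make the setup concrete, here is a minimal sketch of that pattern with a hard-coded cutoff. The names `sum_unrolled`, `sum_simple`, `mysum`, and `SMALL_CUTOFF` are made up for illustration, and the value 8 is an arbitrary placeholder rather than anything measured.

```julia
# Hypothetical cutoff; 8 is an arbitrary placeholder, not a tuned value.
const SMALL_CUTOFF = 8

# Manually 4-way unrolled sum, intended for short arrays.
function sum_unrolled(xs::AbstractVector{T}) where {T}
    s = zero(T)
    i = firstindex(xs)
    stop = lastindex(xs)
    # Process four elements per iteration while at least four remain.
    while i + 3 <= stop
        s += xs[i] + xs[i+1] + xs[i+2] + xs[i+3]
        i += 4
    end
    # Handle the remaining zero to three elements.
    while i <= stop
        s += xs[i]
        i += 1
    end
    return s
end

# Plain loop, left for the compiler to optimize.
function sum_simple(xs::AbstractVector{T}) where {T}
    s = zero(T)
    for x in xs
        s += x
    end
    return s
end

# Choose an implementation based on the input length.
mysum(xs) = length(xs) <= SMALL_CUTOFF ? sum_unrolled(xs) : sum_simple(xs)
```

Both branches give the same answer; only the constant decides which one runs, and that constant is exactly the thing that would ideally be tuned per machine.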
This gets awkward, because the optimal cutoff depends not only on the algorithms, but on the machine running the code. So what we’d really like is a way to tune this, ideally without the end-user needing to worry about it.
I found a way to do this by entirely misusing generated functions. But… it seems to work? Here’s the code, with f and g chosen to actually do some timing, to show that this is possible:
using Static: StaticInt

# Cheaper for small values
function f(x)
    t0 = time_ns()
    t1 = t0 + 50x
    @elapsed while time_ns() < t1
    end
end

# Cheaper for large values
function g(x)
    t0 = time_ns()
    t1 = t0 + 1000
    @elapsed while time_ns() < t1
    end
end

# Set initial bounds
(lo, hi) = (0, 100)

# Warm up, to avoid measuring compile time
f(lo)
g(lo)
f(hi)
g(hi)

function bisect(f, g, lo::Int, hi::Int)
    f_lo = f(lo)
    f_hi = f(hi)
    g_lo = g(lo)
    g_hi = g(hi)
    while (hi - lo) > 1
        x = round(Int, (lo + hi) / 2)
        fx = f(x)
        gx = g(x)
        if fx < gx
            (lo, f_lo) = (x, fx)
        else
            (hi, f_hi) = (x, fx)
        end
    end
    return lo
end

@generated function cutoff(::F, ::G, ::StaticInt{lo}, ::StaticInt{hi}) where {F,G, lo, hi}
    # Recover functions from their types
    f = F.instance
    g = G.instance
    # Find the best value
    x = StaticInt(bisect(f, g, lo, hi))
    quote
        # Hard-code the result
        $x
    end
end
And the result:
julia> cutoff(f, g, StaticInt(lo), StaticInt(hi))
static(19)

julia> @code_warntype cutoff(f, g, StaticInt{lo}(), StaticInt{hi}())
MethodInstance for cutoff(::typeof(f), ::typeof(g), ::StaticInt{0}, ::StaticInt{100})
  from cutoff(::F, ::G, ::StaticInt{lo}, ::StaticInt{hi}) where {F, G, lo, hi} in Main at REPL[9]:1
Static Parameters
  F = typeof(f)
  G = typeof(g)
  lo = 0
  hi = 100
Arguments
  #self#::Core.Const(cutoff)
  _::Core.Const(f)
  _::Core.Const(g)
  _::Core.Const(static(0))
  _::Core.Const(static(100))
Body::StaticInt{19}
1 ─ return static(19)
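As a sanity check on the search itself, the bisection can be run against a deterministic cost model instead of wall-clock timings. This is only a sketch under assumptions: `cost_f` and `cost_g` are hypothetical stand-ins that mirror the 50x and 1000 nanosecond busy-wait targets, and `find_cutoff` is a simplified version of the bisection that uses integer division for the midpoint and drops the unused endpoint measurements.

```julia
# Hypothetical cost models mirroring the busy-wait targets of f and g (in ns).
cost_f(x) = 50x     # grows with the input
cost_g(x) = 1000    # constant

# Simplified bisection: returns the largest probed x at which `a` is still
# strictly cheaper than `b`. Assumes `a` is cheaper near lo, `b` is cheaper
# near hi, and there is a single crossover in between.
function find_cutoff(a, b, lo::Int, hi::Int)
    while hi - lo > 1
        x = (lo + hi) ÷ 2
        if a(x) < b(x)
            lo = x
        else
            hi = x
        end
    end
    return lo
end

find_cutoff(cost_f, cost_g, 0, 100)  # 19
```

With these idealized costs the two are tied at x = 20, and because the comparison is strict the search settles on 19, which is the same value the timed run reported above.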
Is it plausible to do something like this to auto-tune “switch points” for a particular machine? Would this break horribly? Is there a better way to do this?
EDIT: Better to return a ::Val
EDIT2: Better still, Static.StaticInt
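For anyone who would rather not depend on Static.jl, the `Val` variant from the first edit might look roughly like the sketch below. The names `cutoff_val`, `unval`, and `h` are made up here, the midpoint uses integer division rather than `round`, and because the generator runs timing code the result will vary between machines and runs. Note that `f`, `g`, and `bisect` are defined before the generated function that calls them.

```julia
# Toy workloads as in the post: busy-wait, then return the elapsed seconds.
function f(x)   # cheaper for small x
    t1 = time_ns() + 50x
    @elapsed while time_ns() < t1
    end
end

function g(x)   # cheaper for large x
    t1 = time_ns() + 1000
    @elapsed while time_ns() < t1
    end
end

# Timing-based bisection over the integer range lo:hi.
function bisect(f, g, lo::Int, hi::Int)
    while hi - lo > 1
        x = (lo + hi) ÷ 2
        if f(x) < g(x)
            lo = x
        else
            hi = x
        end
    end
    return lo
end

# Same trick as `cutoff`, but with Base's Val in place of StaticInt.
@generated function cutoff_val(::F, ::G, ::Val{lo}, ::Val{hi}) where {F,G,lo,hi}
    # Runs once per combination of argument types, when the method is generated.
    x = bisect(F.instance, G.instance, lo, hi)
    # Hard-code the measured value into the returned expression.
    return :(Val{$x}())
end

# Hypothetical helper to unwrap the type-level integer.
unval(::Val{N}) where {N} = N

# Hypothetical dispatcher that switches implementation at the tuned cutoff.
h(x) = x <= unval(cutoff_val(f, g, Val(0), Val(100))) ? f(x) : g(x)
```

After the first call, `cutoff_val(f, g, Val(0), Val(100))` should behave like a constant, so the comparison in `h` costs no more than one against a literal.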