This problem comes from several pieces of code I’ve been struggling to optimize because of my misunderstanding of the Julia compiler. I would like to know if there is a “standard” way to do this.
Consider a function f that does a complex computation, let’s say
function f(::Val{I}) where I
    return I^2
end
Assume that you don’t want to compute I^2 at run time, but would rather cache the results. I know there may be other ways to do this without using Val, but for some problems we want the work done at compile time.
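For reference, one Val-free way to cache the results could look like the sketch below. The names SQUARES and f_table are hypothetical, not from the code above, and it assumes the inputs are known to lie in 1:10.

```julia
# Hypothetical names for this sketch: SQUARES and f_table are not part of the original code.
# The tuple is built once, when the constant is defined.
const SQUARES = ntuple(i -> i^2, 10)

# Plain tuple indexing: no dispatch on Val, but i must lie in 1:10
# (anything else throws a BoundsError).
f_table(i::Int) = SQUARES[i]

println(f_table(4))
```

This only covers the case where the cached result is a plain value. It does not help when the body of f really needs I as a type parameter.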
Let’s say for example we want to compute
function g(x)
    s = 0
    for i in (1:x)
        u = i%10+1
        s += f(Val(u))
    end
    return s
end
The issue here is that u is only known at run time, so the compiler cannot know in advance which specialization of f it will have to call and has to fall back on dynamic dispatch. This leads to poor performance with many allocations:
julia> @btime g(23)
3.657 μs (8 allocations: 128 bytes)
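A rough way to see the difference is to ask inference what it can prove in each case. This is a sketch only: f_demo and h are hypothetical names, and the exact inferred types may vary between Julia versions.

```julia
# Same shape as f above, under a hypothetical name so the snippet is self-contained.
f_demo(::Val{I}) where {I} = I^2

# The value 3 is part of the argument type, so inference can work out the result type.
static_types = Base.return_types(f_demo, (Val{3},))

# Here the argument is a runtime Int: Val(u) has no concrete type at compile time,
# so the call to f_demo has to be resolved by dynamic dispatch.
h(u::Int) = f_demo(Val(u))   # hypothetical helper for this sketch
dynamic_types = Base.return_types(h, (Int,))

println(static_types)    # expected to be a concrete integer type
println(dynamic_types)   # typically a non-concrete type such as Any
```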
This post about piping and inline provides the solution of inlining f, so let’s try this:
@inline function f_inlined(::Val{I}) where I
    return I^2
end
function g_inlined(x)
    s = 0
    for i in (1:x)
        u = i%10+1
        s += f_inlined(Val(u))
    end
    return s
end
julia> @btime g_inlined(23)
459.452 ns (8 allocations: 128 bytes)
This is much better. However, there is still a huge performance gap between this and what we could do if we knew all the values of i^2 for i in {1, …, 10} in advance.
So the solution I found is this: if we know that I only takes values in a small, constant, sorted set fixed in advance, like (1:10), we can expand the code at parse time into a binary search made of if/else statements, and call the concrete method that we want, e.g. f(Val(1)), …, f(Val(10)), in each of the 10 innermost branches.
Here is the macro to do this (thank you Claude):
macro bounded_call(f, I, Range)
    # Evaluate Range at macro expansion time, in the caller's module
    # (it must be a constant that is already defined there)
    range_vals = Core.eval(__module__, Range)
    function build_if_tree(indices, start_idx, stop_idx)
        if stop_idx == start_idx
            # Base case: single value
            val = indices[start_idx]
            return :($f(Val($val)))
        else
            # Binary split
            mid_idx = start_idx + (stop_idx - start_idx) ÷ 2
            mid_val = indices[mid_idx]
            left = build_if_tree(indices, start_idx, mid_idx)
            right = build_if_tree(indices, mid_idx + 1, stop_idx)
            return quote
                if $I <= $mid_val
                    $left
                else
                    $right
                end
            end
        end
    end
    if length(range_vals) == 0
        return :(error("Empty range"))
    elseif length(range_vals) == 1
        val = range_vals[1]
        return esc(:($f(Val($val))))
    else
        # Values outside the range fall through to the nearest leaf
        tree = build_if_tree(collect(range_vals), 1, length(range_vals))
        return esc(tree)
    end
end
This gives:
const MyRange = collect(1:10)
function g_static(x)
    s = 0
    for i in (1:x)
        u = i%10+1
        s += @bounded_call(f, u, MyRange)
    end
    return s
end
@btime g_static(23)
20.648 ns (0 allocations: 0 bytes)
Much better.
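For comparison, a similar expansion can be sketched with `Base.Cartesian.@nif`, which generates a linear if/elseif chain instead of a binary search. The names f_sq and g_nif are hypothetical, and the sketch assumes the values are exactly 1 to 10, because `@nif` needs a literal integer count.

```julia
import Base.Cartesian: @nif

# Hypothetical names for this sketch (f_sq, g_nif); same computation as f and g_static above.
f_sq(::Val{I}) where {I} = I^2

function g_nif(x)
    s = 0
    for i in 1:x
        u = i % 10 + 1
        # Expands to: if u == 1 f_sq(Val(1)) elseif u == 2 ... else f_sq(Val(10)) end
        # Note: any u that matches none of 1..9 ends up in the final else branch.
        val = @nif 10 d->(u == d) d->f_sq(Val(d))
        s += val
    end
    return s
end

println(g_nif(23))
```

With only 10 cases a linear chain is probably fine. For much larger sets the binary split of the macro above needs fewer comparisons per call.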
Maybe I’m reinventing the wheel, but I’ve been confused a lot by this subtle use of Val. So, is there a standard way of calling a static function with a runtime value that we know belongs to a small set of fixed size?
Also, it’s pretty funny to try larger static values for MyRange, like collect(1:10000). The parse time and compile time become much slower (we parse 10000 if/else statements!), but the runtime stays the same. However, do not try much larger values, otherwise your editor will crash…