Static call at runtime

This problem comes from several pieces of code I’ve been struggling to optimize due to my misunderstanding of the Julia compiler. I would like to know if there is any “standard” way to do this.

Consider a function f that does a complex computation, say

function f(::Val{I}) where I
    return I^2
end

Assume that you don’t want to compute this I^2 at runtime, but rather cache the results. I know there might be other ways to do this without using Vals, but there are some problems where we might want to do this at compile time.
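For concreteness (a quick check of my own, not part of the question): each distinct Val argument triggers a fresh specialization of f, and the compiler can fold the result into a constant:

```julia
# Same f as above: the value I lives in the type domain, so each
# Val{I} gets its own compiled specialization.
f(::Val{I}) where {I} = I^2

f(Val(4))  # returns 16, computed from the type parameter I = 4
# @code_typed f(Val(4)) typically shows the body reduced to a constant
```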

Let’s say for example we want to compute

function g(x)
    s = 0
    for i in (1:x)
        u = i%10+1
        s += f(Val(u))
    end
    return s
end

The issue here is that u is only known at runtime, so the compiler does not know in advance which specialization of f it will have to call. This is why it leads to poor performance with many allocations:

julia> @btime g(23)
  3.657 μs (8 allocations: 128 bytes)
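
One way to see the problem directly (a quick check of mine, not from the original post): when u is runtime data, Val(u) is inferred only as the abstract UnionAll type Val, so every call to f must dispatch dynamically:

```julia
# `wrap` is an illustrative helper name, not from the post.
wrap(u) = Val(u)

# Inference can only conclude "some Val", not a concrete Val{...}:
inferred = Base.return_types(wrap, (Int,))[1]
isconcretetype(inferred)  # false
```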

This post about piping and inline suggests inlining f, so let’s try this:

@inline function f_inlined(::Val{I}) where I
    return I^2
end
function g_inlined(x)
    s = 0
    for i in (1:x)
        u = i%10+1
        s += f_inlined(Val(u))
    end
    return s
end

julia> @btime g_inlined(23)
  459.452 ns (8 allocations: 128 bytes)

This is much better. However, there is still a huge performance gap between this and what we could do if we knew all the values of i^2 for i \in \{1, \dots, 10\} in advance.

So the solution I found is this: if we know that I takes values only in a small, constant, sorted set fixed in advance, like (1:10), we can expand the code at parse time into a binary search made of nested if/else statements, calling the concrete method we want, e.g. f(Val(1)), …, f(Val(10)), in each of the 10 leaf cases.

Here is the macro to do this (thank you Claude):

macro bounded_call(f, I, Range)
    # Evaluate Range at macro expansion time
    range_vals = eval(Range)

    function build_if_tree(indices, start_idx, stop_idx)
        if stop_idx == start_idx
            # Base case: single value
            val = indices[start_idx]
            return :($f(Val($val)))
        else
            # Binary split
            mid_idx = start_idx + (stop_idx - start_idx) ÷ 2
            mid_val = indices[mid_idx]

            left = build_if_tree(indices, start_idx, mid_idx)
            right = build_if_tree(indices, mid_idx + 1, stop_idx)

            return quote
                if $I <= $mid_val
                    $left
                else
                    $right
                end
            end
        end
    end

    if length(range_vals) == 0
        return :(error("Empty range"))
    elseif length(range_vals) == 1
        val = range_vals[1]
        return esc(:($f(Val($val))))
    else
        tree = build_if_tree(collect(range_vals), 1, length(range_vals))
        return esc(tree)
    end
end

This gives:

const MyRange = collect(1:10)
function g_static(x)
    s = 0
    for i in (1:x)
        u = i%10+1
        s += @bounded_call(f, u, MyRange)
    end
    return s
end
@btime g_static(23)
  20.648 ns (0 allocations: 0 bytes)

Much better.

Maybe I’m reinventing the wheel, but I’ve been confused a lot by this subtle use of Val. So, is there any standard way of calling a statically specialized function with a runtime value that we know belongs to a small fixed set?

Also, it’s pretty funny to try larger ranges for MyRange, like collect(1:10000). Parse and compile times become much slower (we parse 10000 if/else statements!), but the runtime stays the same. However, do not try much larger values, otherwise your editor will crash…

There’s no way to get around some kind of lookup table in that case (in any language, I dare say).
Luckily, Julia makes this easy:


@generated function f_inlined(i::Int, ::Val{NCached}) where NCached
    exprs = [:(i === $n && return $(n^2)) for n in 1:NCached]
    quote
        $(exprs...)
        return i^2  # fallback for i > NCached
    end
end

function g_inlined(x)
    s = 0
    for i in (1:x)
        u = i % 10 + 1
        s += f_inlined(u, Val(10))
    end
    return s
end
@btime g_inlined($(23))

18.537 ns (0 allocations: 0 bytes)


In essence, the issue is that the compiler doesn’t know that for

for i in 1:x
     u = i % 10 + 1

u can only take a small set of constant values. I have a PR that would allow the compiler to represent sets of constants (but it has been on the back burner for a while). Currently the Julia compiler can only represent a single Const.

As @sdanisch said, one way is to use a generated function; another, and one of my favorites, is ntuple:

function f(I)
   return I^2
end

function g(x)
   s = Ref(0)
   for i in (1:x)
       u = i%10+1
       ntuple(Val(10)) do j
          if u == j
            s[] += f(j)
          end
          nothing
       end
   end
   return s[]
end
julia> @btime g_inlined(23)
  21.374 ns (0 allocations: 0 bytes)
799

julia> @btime g(23)
  1.883 ns (0 allocations: 0 bytes)
799

The reason ntuple(Val) works is that it is basically an unrolled loop.
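
To make the unrolling concrete (my own illustration, not @vchuravy’s code): ntuple with a Val length is a @generated function, so the tuple below is built without any runtime loop:

```julia
# ntuple(h, Val(3)) behaves like the hand-written (h(1), h(2), h(3)),
# but the expansion happens at compile time.
unrolled = ntuple(j -> j^2, Val(3))
manual   = (1^2, 2^2, 3^2)
unrolled == manual  # true
```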


Thank you both! This is really helpful, exactly what I need. And yeah, @vchuravy, this is especially important on GPUs, where dynamic dispatch produces compilation errors.

To push this a little bit further, even if I do not need it for my use cases (and I think this is not very important, this is just for fun): imagine you have a larger NCached=1000.
The problem with sequential testing in your function is that for values close to 1000 the computing time will be much larger.

@sdanisch’s function achieves

@generated function f(i::Int, ::Val{NCached}) where NCached
    exprs = [:(i == $n && return $(n^2)) for n in 1:NCached]
    quote
        $(exprs...)
        return i^2  # fallback for i > NCached
    end
end
function g(x)
    s = 0
    for i in (1:x)
        u = i % 1000 + 1
        s += f(u, Val(1024))
    end
    return s
end
julia> @btime g(1000)
  2.970 μs (0 allocations: 0 bytes)

A function based on a binary search tree achieves

@generated function f_tree(i::Int, ::Val{NCached}) where NCached
    function build_tree(range)
        if length(range) == 0
            return nothing
        elseif length(range) == 1
            n = range[1]
            return :(return $(n^2))
        else
            mid = (first(range) + last(range)) ÷ 2
            left_range = first(range):mid
            right_range = (mid+1):last(range)

            left_tree = build_tree(left_range)
            right_tree = build_tree(right_range)

            if right_tree === nothing
                return :((i <= $mid) && $left_tree)
            else
                return :(((i <= $mid) && $left_tree; $right_tree))
            end
        end
    end

    tree = build_tree(1:NCached)

    return quote
        $tree
        return i^2
    end
end
function g_tree(x)
    s = 0
    for i in (1:x)
        u = i % 1000 + 1
        s += f_tree(u, Val(1024))
    end
    return s
end
@btime g_tree(1000)
  3.239 μs (0 allocations: 0 bytes)

This is slightly worse; do you know why? It shouldn’t be, because in theory we make \log_2(1000) \cdot 1000 \approx 9966 checks in total, while the sequential generated function makes 1 + 2 + \cdots + 1000 \approx 500000 checks in total.
Also, if it turns out we only use the last values (around 990, as in the code below), g_tree is better:

function g_worst(x)
    s = 0
    for i in (1:x)
        u = i % 5 + 990
        s += f(u, Val(1000))
    end
    return s
end
function g_tree_worst(x)
    s = 0
    for i in (1:x)
        u = i % 5 + 990
        s += f_tree(u, Val(1000))
    end
    return s
end
julia> @btime g_worst(1000)
  2.731 μs (0 allocations: 0 bytes)
julia> @btime g_tree_worst(1000)
  786.620 ns (0 allocations: 0 bytes)

And in fact, curiously, g_tree_worst is also better if you replace 990 by 0… which is very confusing compared to the above result with i%1000.

If you check the compiled code:

julia> @code_native f(1, Val(32)) # easier to read than 1024

you will see that the ifs were compiled down to a jump table.

Meanwhile the tree code looks like you would expect (see @code_native f_tree(1, Val(32))): it is much longer and much more complicated, which actually matters to modern CPUs. So it surprises me that it comes so close in performance.

EDIT: If you force inlining of f, the performance difference vanishes for me:

julia> function g_inline_worst(x)
           s = 0
           for i in (1:x)
               u = i % 5 + 990
               s += @inline f(u, Val(1000))
           end
           return s
       end
g_inline_worst (generic function with 1 method)

julia> @btime g_worst(1000)
  5.855 μs (0 allocations: 0 bytes)
984066000

julia> @btime g_tree_worst(1000)
  1.041 μs (0 allocations: 0 bytes)
984066000

julia> @btime g_inline_worst(1000)
  973.438 ns (0 allocations: 0 bytes)
984066000

Ok, thanks a lot. I always struggle to understand when it’s necessary to @inline or not. This clearly shows that the normal (inlined) version is better than the tree version.

However (this is super anecdotal, I guess), for larger values LLVM splits the cases into several switch tables of 128 or 256 elements, and it looks like the compiler determines which table contains the target index through sequential range checks rather than binary search.

So I just tried trees with NCached total leaves and NLeaves leaves sharing the same parent. In particular, NLeaves=2 corresponds to the binary tree (g_tree) and NLeaves=NCached corresponds to sequential if/else statements (g).

When NLeaves=512, I get optimal performance on my machine:

@inline @generated function f_branched(i::Int, ::Val{NCached}, ::Val{NLeaves}) where {NCached,NLeaves}
    NBranch = NCached ÷ NLeaves

    function build_branch(offset, n_items)
        # Generate sequential checks for this branch
        exprs = [:(i == $k && return $(k^2)) for k in offset:(offset+n_items-1)]
        return Expr(:block, exprs...)
    end

    function build_tree(branch_idx, total_branches, leaves_per_branch)
        if branch_idx > total_branches
            return nothing
        elseif branch_idx == total_branches
            # Last branch - just generate the checks
            offset = (branch_idx - 1) * leaves_per_branch + 1
            return build_branch(offset, leaves_per_branch)
        else
            # Not last branch - use if/else to split
            mid_value = branch_idx * leaves_per_branch
            offset = (branch_idx - 1) * leaves_per_branch + 1

            left_checks = build_branch(offset, leaves_per_branch)
            right_tree = build_tree(branch_idx + 1, total_branches, leaves_per_branch)

            return quote
                if i <= $mid_value
                    $left_checks
                else
                    $right_tree
                end
            end
        end
    end

    tree = build_tree(1, NBranch, NLeaves)

    return quote
        $tree
        return i^2  # Fallback for values outside cached range
    end
end


f_branched(3, Val(256), Val(128))

function g(x)
    s = 0
    for i in (1:x)
        u = i % 5000+1
        s += @inline f(u, Val(1024*8))
    end
    return s
end
function g_tree(x)
    s = 0
    for i in (1:x)
        u = i % 5000+1
        s += @inline f_tree(u, Val(1024*8))
    end
    return s
end
function g_branched(x)
    s = 0
    for i in (1:x)
        u = i % 5000+1
        s += @inline f_branched(u, Val(1024*8), Val(512))
    end
    return s
end
julia> @btime g(5000)
  43.846 μs (0 allocations: 0 bytes)
41679167500

julia> @btime g_tree(5000)
  48.247 μs (0 allocations: 0 bytes)
41679167500

julia> @btime g_branched(5000)
  8.498 μs (0 allocations: 0 bytes)
41679167500

I guess this hasn’t been optimized in LLVM for such large values because it’s rarely useful in practice…

Btw, just because no one has addressed the ‘correct’ solution to your original problem so far: the canonical answer to not recomputing expensive results is memoization, which basically uses a dictionary to store computed values. This is flexible, robust, and very convenient. Have a look at Memoize.jl :slight_smile:
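
To sketch the idea (a hand-rolled version; Memoize.jl essentially wraps this pattern in a @memoize macro, and the names below are mine):

```julia
# A Dict-backed cache: compute each value once, then reuse it.
const squares_cache = Dict{Int,Int}()

function f_memoized(i::Int)
    get!(squares_cache, i) do
        i^2  # stands in for the expensive computation
    end
end

f_memoized(7)  # computes 49 and stores it in the cache
f_memoized(7)  # second call is just a Dict lookup
```

Note that unlike the Val/generated approaches above, this pays a hash lookup on every call, but it works for arbitrary (not just small, fixed) argument sets and needs no recompilation.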