How to "unroll" getindex calls on a Tuple parametrically?

Consider the following code:

abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

op(o::Add, value) = o.foo + value
op(o::Mul, value) = o.foo * value

abstract type AbstractCalculator end

struct Calculator{O} <: AbstractCalculator where {O<:Tuple{Op}}
    ops::O
end

c = Calculator((Add(5), Mul(2)))

function run_op(calculator::C, op_index, value) where {C<:Calculator}
    o = calculator.ops[op_index]
    op(o, value)
end

using BenchmarkTools
# Interpolate the global `c` with `$` so the non-constant global lookup is not timed as well
@benchmark run_op($c, 1, 3.5)
@code_warntype run_op(c, 1, 3.5)

function run_op_manual(calculator::Calculator, op_index, value)
    if op_index == 1
        return op(calculator.ops[1], value)
    elseif op_index == 2
        return op(calculator.ops[2], value)
    else
        error("boink")
    end
end

@benchmark run_op_manual($c, 1, 3.5)
@code_warntype run_op_manual(c, 1, 3.5)

When calling run_op, the variable o is naturally type-unstable and causes heap allocations: it can have one of two types, depending on the value of the op_index argument and not just on its type. That seems reasonable enough to me. By contrast, a manually "unrolled" version such as run_op_manual is type-stable, allocates less, and is faster as a consequence.

Of course, the manual version only works for this specific Calculator type: if there is a different number of ops, or they have different types, it will either fail explicitly or return wrong results. Nevertheless, all the information required to build it seems to be contained in the Calculator type signature, which suggests that something similar could be generated at compile time for every signature that is called (perhaps through metaprogramming?). Am I right in concluding that? If so, what would be the best way to do it?
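A quick way to check that intuition: the number of ops and the concrete type of each one can be read off the type of the ops tuple alone, without looking at any runtime value. This is a minimal sketch; it assumes the bound on Calculator is written as suggested in the reply further down (spelled here as Tuple{Vararg{Op}}).

```julia
abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

op(o::Add, value) = o.foo + value
op(o::Mul, value) = o.foo * value

abstract type AbstractCalculator end

# Tuple{Vararg{Op}} is the same set of types as NTuple{N,Op} where N
struct Calculator{O<:Tuple{Vararg{Op}}} <: AbstractCalculator
    ops::O
end

c = Calculator((Add(5), Mul(2)))

# Everything the hand-written run_op_manual relies on is part of the type:
T = typeof(c.ops)
@show fieldcount(T)   # number of ops, i.e. number of branches to write
@show fieldtypes(T)   # concrete type of each op, i.e. which `op` method each branch calls
```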

If op_index is never computed but is constant, you could use Vals. For example, replace

function run_op(calculator::C, op_index, value) where {C<:Calculator}

with

function run_op(calculator::C, i::Val{op_index}, value) where {C<:Calculator,op_index}

and then call it like run_op(c, Val(1), 3.5). With that change (to both functions), the benchmarks are very similar. Once you have to iterate over each element of calculator.ops, though, this approach seems to work less well.
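A self-contained sketch of the Val approach. The function name run_op_val is hypothetical, and Calculator uses the corrected type bound discussed later in the thread. Because op_index is a type parameter, calculator.ops[op_index] indexes the tuple with a compile-time constant.

```julia
abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

op(o::Add, value) = o.foo + value
op(o::Mul, value) = o.foo * value

abstract type AbstractCalculator end

# Tuple{Vararg{Op}} is the same set of types as NTuple{N,Op} where N
struct Calculator{O<:Tuple{Vararg{Op}}} <: AbstractCalculator
    ops::O
end

# op_index lives in the type domain, so each Val(k) gets its own specialized method
function run_op_val(calculator::Calculator, ::Val{op_index}, value) where {op_index}
    return op(calculator.ops[op_index], value)
end

c = Calculator((Add(5), Mul(2)))
run_op_val(c, Val(1), 3.5)   # 8.5  (5 + 3.5)
run_op_val(c, Val(2), 3.5)   # 7.0  (2 * 3.5)
```

If the index is only known at run time, constructing Val(i) from it moves the dynamic dispatch to the call site instead of removing it, which is consistent with the limitation described above.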

If your goal is to chain the computations (i.e., something like op(ops[2], op(ops[1], value))), then you could use reduce:

function run_reduce(calculator::Calculator, value)
    reduce(calculator.ops; init = Add(value)) do o1, o2
        Add(op(o2, o1.foo))
    end.foo
end

(Admittedly, this function looks a bit strange because it has to wrap every intermediate value in an Op, which is somewhat arbitrary. With some refactoring of your structs, it could probably be made to look better.)
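One possible version of that refactoring, sketched under the assumption that "chaining" means applying ops[1] first: foldl threads the plain value through the tuple, so nothing has to be wrapped in an Add. The name run_chain is hypothetical. foldl is used instead of reduce because the Julia documentation leaves the associativity of reduce implementation-dependent, while foldl is guaranteed to be left-associative, and this step function is not associative (its two arguments are not even the same kind of value).

```julia
abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

op(o::Add, value) = o.foo + value
op(o::Mul, value) = o.foo * value

abstract type AbstractCalculator end

# Tuple{Vararg{Op}} is the same set of types as NTuple{N,Op} where N
struct Calculator{O<:Tuple{Vararg{Op}}} <: AbstractCalculator
    ops::O
end

# Apply the ops left to right: value -> op(ops[1], value) -> op(ops[2], ...) -> ...
run_chain(calculator::Calculator, value) =
    foldl((acc, o) -> op(o, acc), calculator.ops; init = value)

c = Calculator((Add(5), Mul(2)))
run_chain(c, 3.5)   # 17.0, i.e. 2 * (5 + 3.5)
```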

I’ve never worked with @generated functions before, so I took this opportunity to try it out. Here’s what I came up with:

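@generated function run_op_manual_generated(calculator::Calculator{<:NTuple{N,Op}}, op_index, value) where {N}
    # Inside a @generated body, N is the tuple length, known from the argument types
    N < 1 && return :(error("boink"))
    # Build `if op_index == 1 ... elseif op_index == 2 ... else error("boink") end`
    ex = Expr(:if, :(op_index == 1), :(return op(calculator.ops[1], value)))
    args = ex.args
    for i = 2:N
        # Append an elseif branch, then descend into it so the next branch nests inside
        push!(args, Expr(:elseif, :(op_index == $i), :(return op(calculator.ops[$i], value))))
        args = args[3].args
    end
    # Final else branch: the index is out of range
    push!(args, :(error("boink")))
    return ex
end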

This seems to run just as fast as run_op_manual and is type-stable, even for larger Calculators (I tried 26 operations; maybe it breaks if you go higher).
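For comparison, here is a different technique that does not appear in this thread: recursing on the tuple with Base.tail, so that each call sees a shorter tuple type and the compiler specializes one method per remaining length, which unrolls the branches without @generated. The name run_op_recursive is hypothetical, and Test.@inferred is only used to check that the return type is inferred concretely; this is a sketch, not a benchmarked replacement for the generated version.

```julia
using Test

abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

op(o::Add, value) = o.foo + value
op(o::Mul, value) = o.foo * value

abstract type AbstractCalculator end

# Tuple{Vararg{Op}} is the same set of types as NTuple{N,Op} where N
struct Calculator{O<:Tuple{Vararg{Op}}} <: AbstractCalculator
    ops::O
end

# Base case: no ops left to try, so the index was out of range
run_op_recursive(ops::Tuple{}, op_index, value) = error("boink")

# Check the first op, otherwise recurse on the remaining (shorter) tuple
function run_op_recursive(ops::Tuple, op_index, value)
    if op_index == 1
        return op(first(ops), value)
    else
        return run_op_recursive(Base.tail(ops), op_index - 1, value)
    end
end

run_op_recursive(calculator::Calculator, op_index, value) =
    run_op_recursive(calculator.ops, op_index, value)

c = Calculator((Add(5), Mul(2), Add(1)))
# @inferred throws if the inferred return type differs from the actual one (Float64)
@inferred run_op_recursive(c, 2, 3.5)   # 7.0
```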

Hopefully something I’ve said is helpful.


By the way,

struct Calculator{O} <: AbstractCalculator where {O<:Tuple{Op}}

is probably not doing what you think it is. As written, the where {O<:Tuple{Op}} actually does nothing, because there is no type parameter O in AbstractCalculator (the O in Calculator{O} is not the same O). You probably meant

struct Calculator{O<:NTuple{N,Op} where N} <: AbstractCalculator

(The NTuple{N,Op} is necessary to allow an arbitrary number of Ops; Tuple{Op} allows just one, i.e., Tuple{Op} == NTuple{1,Op}.)
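The difference between the two bounds can be checked directly with isa. A small sketch (the three-op calculator c3 is just an example; Tuple{Vararg{Op}} is used as an equivalent spelling of NTuple{N,Op} where N):

```julia
abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

abstract type AbstractCalculator end

# Tuple{Vararg{Op}} is the same set of types as NTuple{N,Op} where N
struct Calculator{O<:Tuple{Vararg{Op}}} <: AbstractCalculator
    ops::O
end

# Tuple{Op} only describes one-element tuples...
@show (Add(5),) isa Tuple{Op}                 # true
@show (Add(5), Mul(2)) isa Tuple{Op}          # false
@show Tuple{Op} == NTuple{1,Op}               # true
# ...whereas Tuple{Vararg{Op}} accepts tuples of Ops of any length
@show (Add(5), Mul(2)) isa Tuple{Vararg{Op}}  # true

c3 = Calculator((Add(5), Mul(2), Add(1)))     # accepted: three ops
# Calculator((1, 2)) throws a MethodError, since (1, 2) is not a tuple of Ops
```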


That’s very helpful indeed, massive thanks! I think the @generated version is the one I’m going to try on my original problem. I had tried using Val{...} before, but (as you note) it did not solve my problem, since the function is called in a hot ODE loop and the integer flag is not constant.

For posterity, I found another metaprogramming solution using @nexprs that seems to be just as fast as the generated version while being a bit more compact and legible:

using Base.Cartesian: @nexprs
@generated function run_op_nexprs(calculator::Calculator{<:NTuple{N,Op}}, op_index, value) where {N}
    quote
        # Unrolls into N statements, one per j = 1:N, each returning early if op_index == j
        @nexprs $N j -> ((op_index == j) ? (return op(calculator.ops[j], value)) : nothing)
        return error("boink")
    end
end
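A usage sketch for run_op_nexprs with a four-op calculator, using Test.@inferred as a scriptable alternative to reading @code_warntype output. The calculator c4 is hypothetical, and Calculator uses the corrected bound suggested earlier in the thread.

```julia
using Base.Cartesian: @nexprs
using Test

abstract type Op end

struct Add{T} <: Op
    foo::T
end

struct Mul{T} <: Op
    foo::T
end

op(o::Add, value) = o.foo + value
op(o::Mul, value) = o.foo * value

abstract type AbstractCalculator end

# Tuple{Vararg{Op}} is the same set of types as NTuple{N,Op} where N
struct Calculator{O<:Tuple{Vararg{Op}}} <: AbstractCalculator
    ops::O
end

@generated function run_op_nexprs(calculator::Calculator{<:NTuple{N,Op}}, op_index, value) where {N}
    quote
        # Unrolls into N statements, one per j = 1:N, each returning early if op_index == j
        @nexprs $N j -> ((op_index == j) ? (return op(calculator.ops[j], value)) : nothing)
        return error("boink")
    end
end

c4 = Calculator((Add(5), Mul(2), Add(1), Mul(10)))
# @inferred throws if the return type is not inferred as the actual Float64
@inferred run_op_nexprs(c4, 4, 3.5)   # 35.0  (10 * 3.5)
run_op_nexprs(c4, 1, 3.5)             # 8.5   (5 + 3.5)
```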