Constructing a Switching Function on the Fly

In ExpandNestedData, I have a module for composing many nested iterators. It works by boiling down `repeat` and `cycle` into callable structs that apply integer division and modulo (respectively) to a provided index, routing it back to the correct index of a seed iterator. I then compose them together as I nest them. It works very well for my application, yay!
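For reference, here's a minimal sketch of what I mean by the div/mod routing (illustrative names, not the actual ExpandNestedData internals):

```julia
# Hypothetical sketch: route an index of a repeated/cycled view back to the seed.
# `UnRepeat` undoes "repeat each element n times" via integer division;
# `UnCycle` undoes cycle-style tiling via modulo.
struct UnRepeat
    n::Int
end
(u::UnRepeat)(i) = (i - 1) ÷ u.n + 1

struct UnCycle
    len::Int
end
(u::UnCycle)(i) = (i - 1) % u.len + 1

# Composing them nests the index transformations:
seed = ["a", "b"]
get_index = i -> seed[UnCycle(2)(UnRepeat(3)(i))]
# indices 1:12 walk each element repeated 3 times, tiled twice:
# "a","a","a","b","b","b","a","a","a","b","b","b"
```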

But I also need a similar solution for a lazy vcat. This one needs to take N child iterators, a requested index, and then route the index call to the correct child get_index function.

I can do something like

```julia
struct Unvcat{F,G}
    f::F
    g::G
    f_len::Int
end
(u::Unvcat)(i) = i <= u.f_len ? u.f(i) : u.g(i - u.f_len)
```

Then I can just compose the functions together and we’re good to go.
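Concretely, the composition looks like this (written with closures here just to show the shape of the fold; the struct version behaves the same way):

```julia
# Closure-based equivalent of Unvcat, shown for illustration only:
# route index i to f if it's within the first f_len slots, else shift it into g.
unvcat(f, g, f_len) = i -> i <= f_len ? f(i) : g(i - f_len)

a, b, c = [10, 20], [30], [40, 50]
# fold the getindex functions together pairwise
stacked = unvcat(unvcat(i -> a[i], i -> b[i], length(a)),
                 i -> c[i], length(a) + length(b))
stacked(4)  # == 40, i.e. c[1]
```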

However, I can have hundreds or thousands of NestedIterators to stack, so composing this takes AGES.

So I thought I’d have a go at just building the switch function with metaprogramming and eval the result:

"""make_switch(fs, lengths)
Create a switching function that takes an integer `i` and compares it against 
each length provided in order. Once it accumulates the sum of lengths greater than or equal to 
`i`, it subtracts the previous total length and runs the corresponding function.
function make_switch(fs, lengths)
    func_def = compose_switch_body(fs,lengths)
    @eval $func_def

function compose_switch_body(fs,lengths)
    total_len = sum(lengths)

    _fs = Iterators.Stateful(fs)
    _lengths = Iterators.Stateful(lengths)
    l = popfirst!(_lengths)
    if_stmt = :(
        if i <= $(l) 
    curr_stmt = if_stmt
    prev_l = l
    for (f,l) in zip(_fs, _lengths)
        # insert a `elseif` for every subsequent function
        curr_l = l + prev_l
        ex = Expr(
            :(i <= $(curr_l)),
        push!(curr_stmt.args, ex)
        prev_l = curr_l
        curr_stmt = ex
    name = gensym("unvcat_switch")
    error_str = "Attempted to access $total_len-length vector at index "
    func_def = :(
        function $(name)(i)
            i > $(total_len) && error($error_str * "$i")
    return func_def

It works like a charm. Super fast to make, super fast to run. Feeling really great!

Aaaah, but World Age! Calling iter.get_index(i) now hits an evaled function that doesn't exist yet in the caller's world age.

My solution for now is to define special `iterate` and `collect` versions that provide a function barrier calling `Base.invokelatest`, so iterating over all the values can be fast. But I'd love to give the user a way to use the fast function with the lazy iterator so they can avoid allocations if need be. For now, I have `getindex(n::NestedIterator, i) = Base.invokelatest(n.get_index, i)`, which is fine but obviously not ideal.
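The function-barrier pattern looks roughly like this (a sketch; `NestedIterator` here is a minimal stand-in with just the fields needed, not my real type):

```julia
# Minimal stand-in for the real NestedIterator, for illustration only.
struct NestedIterator{F}
    get_index::F
    len::Int
end

# Slow path: world-age-safe, but one dynamic dispatch per element.
Base.getindex(n::NestedIterator, i) = Base.invokelatest(n.get_index, i)

# Fast path: pay invokelatest once to cross the world-age barrier; inside
# the barrier, the loop specializes on the concrete function type.
fast_collect(n::NestedIterator) = Base.invokelatest(_collect_barrier, n.get_index, n.len)
_collect_barrier(f, len) = [f(i) for i in 1:len]
```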

Are there better solutions I’m missing? Maybe a solution with generated functions? Maybe something to make composing functions faster?

PS, I also considered having a callable struct that holds a vector of functions and a vector of lengths, and then make it callable with a function that iterates over both vectors, but that is terribly type unstable since I can’t know the types of the functions. But maybe I’m overthinking that?


If you use tuples, the compiler should unroll a lot of loops for you without resorting to generated functions etcetera. For example, maybe something like:

```julia
struct LazyVcat{I,S}
    itrs::I   # tuple of iterators
    starts::S # tuple of starting indices
end

function lazyvcat(itrs...)
    starts = cumsum(length.(itrs)) .- length.(itrs) .+ 1
    return LazyVcat{typeof(itrs),typeof(starts)}(itrs, starts)
end

function Base.getindex(a::LazyVcat, i::Integer)
    N = length(a.itrs)
    for j in reverse(ntuple(identity, Val(N)))
        if i ≥ a.starts[j]
            return a.itrs[j][i - a.starts[j] + firstindex(a.itrs[j])]
        end
    end
    throw(BoundsError(a, i)) # i < 1
end
```

which gives

```julia
julia> a = lazyvcat(["a","b","c"], ["d","e","f"]);

julia> [a[i] for i in 1:6]
6-element Vector{String}:
 "a"
 "b"
 "c"
 "d"
 "e"
 "f"
```
and is type-stable if all of the itrs have the same inferred element type.

Of course, this makes getindex O(N), where N is the number of iterators/arrays you have concatenated. You could get it down to O(log N) by using a binary search, I guess. It would take a bit more work to inline the binary search for a fixed N; it might require a generated function unless someone has already optimized a `searchsortedlast` for tuples (julia#50789).
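The scalar lookup itself is just "last j with starts[j] ≤ i". A plain (not unrolled) binary search version looks like this; unrolling it for a fixed-size tuple is the part that might need a generated function:

```julia
# Find the child that owns overall index i: the last j with starts[j] <= i.
# Assumes starts is sorted ascending and starts[1] <= i.
function find_child(starts, i)
    lo, hi = 1, length(starts)
    while lo < hi
        mid = (lo + hi + 1) >> 1   # upper midpoint so lo always advances
        if starts[mid] <= i
            lo = mid
        else
            hi = mid - 1
        end
    end
    return lo
end
```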


I didn’t consider using tuples due to the potentially large number of iters, but I’ll give it a go!

How many iterators are you lazily concatenating? If it’s really big then you won’t want to inline the code either, and you’ll definitely want a log(N) binary search or similar.

One of the use cases is stacking json records from an API. Those records could be in the thousands. I did not consider a binary search until you said it, so I’m going to give that a shot when I get back to my computer!

In that case a tuple and/or code generation (via a static compile-time length) are probably the wrong approach. Just use an array (dynamic length) under the hood, and use searchsortedlast (which does binary search) on an array of starting indices.
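A sketch of that array-backed version (illustrative names; `starts[j]` is the first overall index served by child `j`):

```julia
# Array-backed lazy vcat: dynamic number of children, O(log N) getindex
# via searchsortedlast on the vector of starting indices.
struct LazyVcatVec{T}
    itrs::Vector{T}
    starts::Vector{Int}  # starts[j] = first overall index served by itrs[j]
    len::Int
end

function lazyvcatvec(itrs::Vector)
    lens = length.(itrs)
    starts = cumsum(lens) .- lens .+ 1
    return LazyVcatVec(itrs, starts, sum(lens))
end

function Base.getindex(a::LazyVcatVec, i::Integer)
    1 <= i <= a.len || throw(BoundsError(a, i))
    j = searchsortedlast(a.starts, i)   # binary search: last start <= i
    return a.itrs[j][i - a.starts[j] + 1]
end
```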
