[ANN] Catwalk.jl - With dynamic dispatch to the moon! (an adaptive optimizer, aka JIT compiler)

I ended up implementing one of my pipe dreams, and hey, it works! :smiley:

Catwalk.jl Intro

Catwalk.jl can speed up long-running Julia processes by minimizing the
overhead of dynamic dispatch. It is a JIT compiler that continuously
re-optimizes dispatch code based on data collected at runtime.

[Figure: Speedup demo, with a link to the source code of this test]

It profiles user-specified call sites, estimating the distribution of
dynamically dispatched types during runtime, and generates fast
static routes for the most frequent ones on the fly.
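To make "static routes" concrete, here is a hand-written sketch of the kind of branch that can be emitted for one call site. Inside each isa branch the compiler knows the concrete type of x, so f(x) can be resolved at compile time; anything else falls through to ordinary dynamic dispatch. The name call_with_routes and the choice of Int and Float64 as the "frequent" types are illustrative assumptions, not the code Catwalk.jl actually generates.

```julia
# Hypothetical router: Int and Float64 stand in for the most frequently
# observed types at this call site.
@inline function call_with_routes(f, x)
    if x isa Int
        return f(x)      # x::Int here: the method is resolved at compile time
    elseif x isa Float64
        return f(x)      # x::Float64 here
    else
        return f(x)      # fallback: ordinary dynamic dispatch
    end
end

vals = Any[1, 2.0, 3 + 4im]
results = [call_with_routes(x -> 2x, v) for v in vals]
```

The routes only pay off if they match the types that actually dominate at runtime, which is why the set of routes is chosen from profiling data rather than written by hand.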

The statistical profiler has very low overhead and can be configured
to handle situations where the distribution of dispatched types
changes relatively fast.
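One simple way to let a frequency estimate follow a distribution that changes relatively fast is an exponentially weighted average over profiled batches. This is only an illustration of the idea: update_freqs!, count_types and the alpha parameter are hypothetical names, not Catwalk.jl's profiler or configuration API.

```julia
# Hypothetical: blend the type counts of one profiled batch into a running
# frequency estimate. Larger alpha = faster adaptation, noisier estimate.
function update_freqs!(freqs::Dict{DataType,Float64},
                       batch_counts::Dict{DataType,Int}; alpha::Float64 = 0.5)
    n = sum(values(batch_counts); init = 0)
    n == 0 && return freqs                       # nothing observed in this batch
    for T in union(keys(freqs), keys(batch_counts))  # union builds a new Set
        observed = get(batch_counts, T, 0) / n
        freqs[T] = (1 - alpha) * get(freqs, T, 0.0) + alpha * observed
    end
    return freqs
end

# Count the concrete types seen in a (profiled) batch.
function count_types(vals)
    counts = Dict{DataType,Int}()
    for v in vals
        T = typeof(v)
        counts[T] = get(counts, T, 0) + 1
    end
    return counts
end
```

For example, starting from an empty estimate, a batch of three Ints and one Float64 with alpha = 0.5 gives Int => 0.375 and Float64 => 0.125.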

To minimize compilation overhead, recompilation only occurs when the
distribution has changed enough and the tunable cost model predicts
significant speedup compared to the best version that was previously
compiled.
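The recompilation decision can be sketched with a deliberately crude cost model, assuming a routed call costs static_cost, any other call costs dynamic_cost, and a new set of routes is only worth compiling if it beats the current one by a margin. Every name and default value here (expected_cost, top_types, should_recompile, k, min_gain) is a hypothetical placeholder, not Catwalk.jl's actual cost model or its tuning parameters.

```julia
# Hypothetical cost model: routed calls are cheap, all others pay for dispatch.
function expected_cost(freqs::Dict{DataType,Float64}, routed::Vector{DataType};
                       static_cost = 1.0, dynamic_cost = 20.0)
    total = 0.0
    for (T, p) in freqs
        total += p * (T in routed ? static_cost : dynamic_cost)
    end
    return total
end

# The k most frequent types become the candidate set of static routes.
top_types(freqs, k) =
    first.(sort!(collect(freqs); by = last, rev = true)[1:min(k, length(freqs))])

# Recompile only if the candidate is predicted to be significantly cheaper
# than the routes compiled so far.
function should_recompile(freqs, current_routes; k = 2, min_gain = 0.2)
    candidate = top_types(freqs, k)
    gain = 1 - expected_cost(freqs, candidate) / expected_cost(freqs, current_routes)
    return gain > min_gain, candidate
end
```

With frequencies Int => 0.7, Float64 => 0.2, ComplexF64 => 0.1 and no routes yet, routing Int and Float64 is predicted to cut the cost from 20 to 2.9 units, so it recompiles; asking again with those routes already in place predicts no gain and does nothing.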

When to use this package

Dynamic dispatch in Julia is very fast in itself, so speeding it up is not an easy task.
Catwalk.jl focuses on use cases where it is not feasible to list the dynamically dispatched concrete types in the source code of the call site.

Catwalk.jl assumes the following:

  • The process is long-running: it may take several seconds, possibly minutes, to break even after the initial compilation overhead.
  • Few dynamically dispatched call sites contribute significantly to the running time (dynamic dispatch in a hot loop).
  • You can modify the source code around the interesting call sites (add a macro call), and calculation is organized into batches.

Alternative packages

Alternatives

Usage

Let’s say you have a long-running calculation, organized into batches:

const NUM_BATCHES = 1000

function runbatches()
    for batchidx = 1:NUM_BATCHES
        hotloop()
        # Log progress, etc.
    end
end

The hot loop calls the type-unstable function get_some_x() and passes its result to a relatively cheap calculation, calc_with_x().

const NUM_ITERS_PER_BATCH = 1_000_000

function hotloop()
    for i = 1:NUM_ITERS_PER_BATCH
        x = get_some_x(i)
        calc_with_x(x)
    end
end

const xs = Any[1, 2.0, ComplexF64(3.0, 3.0)]
get_some_x(i) = xs[i % length(xs) + 1]

const result = Ref(ComplexF64(0.0, 0.0))

function calc_with_x(x)
    result[] += x
end

As get_some_x is not type-stable, calc_with_x must be dynamically dispatched, which slows down the calculation.
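A quick way to confirm the instability is to ask type inference for the return type of get_some_x. Base.return_types is a reflection utility, so treat this as a debugging aid; @code_warntype hotloop() shows the same information interactively.

```julia
const xs = Any[1, 2.0, ComplexF64(3.0, 3.0)]
get_some_x(i) = xs[i % length(xs) + 1]

# Inference can only promise `Any`, because xs is a Vector{Any}
rt = only(Base.return_types(get_some_x, (Int,)))
println(rt)   # prints "Any"
```

Since the argument type of calc_with_x(x) is unknown at compile time, every iteration has to look up the method at runtime, and the result of the addition has to be boxed.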

Sometimes it is not feasible to type-stabilize get_some_x. Catwalk.jl is here for those cases.

You mark hotloop*, the outer function, with the @jit macro, and provide the name of the dynamically dispatched function and the argument to operate on (the API will hopefully improve in the future). You also have to add an extra argument named jitctx to the jit-ed function:

using Catwalk

@jit calc_with_x x function hotloop_jit(jitctx)
    for i = 1:NUM_ITERS_PER_BATCH
        x = get_some_x(i)
        calc_with_x(x)
    end
end

The Catwalk optimizer provides the jitctx context, which you have to pass to the jit-ed function manually. Also, every batch needs a bit of housekeeping to drive the Catwalk optimizer:

function runbatches_jit()
    jit = Catwalk.JIT() ## Also works inside a function (no eval used)
    for batch = 1:NUM_BATCHES
        Catwalk.step!(jit)
        hotloop_jit(Catwalk.ctx(jit))
    end
end

Yes, it is a bit complicated to integrate your code with Catwalk, but it may be worth the effort:

result[] = ComplexF64(0, 0)
@time runbatches_jit()

# 4.608471 seconds (4.60 M allocations: 218.950 MiB, 0.56% gc time, 21.68% compilation time)

jit_result = result[]

result[] = ComplexF64(0, 0)
@time runbatches()

# 23.387341 seconds (1000.00 M allocations: 29.802 GiB, 7.71% gc time)

And the results are the same:

jit_result == result[] || error("JIT must be a no-op!")

Please note that the speedup depends on the portion of the runtime spent in dynamic dispatch, which is most likely smaller in your case than in this contrived example.
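The point can be made quantitative with Amdahl's law (a standard rule of thumb, not a formula from the Catwalk.jl documentation): if a fraction p of the runtime is dispatch overhead and that part becomes s times faster, the overall speedup is

```latex
S(p, s) = \frac{1}{(1 - p) + p/s} \;\le\; \frac{1}{1 - p}
```

So if dispatch accounts for half of the runtime (p = 0.5), even an arbitrarily large s gives at most a 2x speedup. Read through this formula, the roughly 5x ratio in the demo (23.4 s vs. 4.6 s) would require p of about 0.8 or more, which is what makes the example contrived.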

Source of this demo: usage.jl
What’s inside: How it works? · Catwalk.jl
Fully tunable: Configuration & tuning · Catwalk.jl

* EDIT: clarification: The name hotloop for this function is misleading. There is a hot loop somewhere in the code, but the function marked with the @jit macro may only be called from the loop body, in which case the macro is not aware of the loop. More than one jit-ed function may also be called from the same loop (example).

This looks awesome, good job!

FYI, as an addition to Alternative packages: FoldsThreads.jl’s “trampoline” also uses the same technique (IIUC) https://github.com/JuliaFolds/FoldsThreads.jl/blob/master/src/trampoline.jl. JuliaFolds packages in general also try to do a weaker version of this, as discussed in: Tail-call optimization and function-barrier -based accumulation in loops

(But it looks like you arrived at a more generic API/implementation, which is great!)

Thanks, I was not aware of that, now added a link to JuliaFolds!

Selecting the best solution for a use case seems hard, and honestly I think that a lot of scientific calculations struggling with similar issues may not need the dynamism provided by Catwalk.jl and would be better served by the alternatives. (My motivation was a long-running server with varying load.)

Yeah, I think we have a big space to explore there. For example, I had to tune down the “JIT” optimization to reduce compilation latency in Transducers.jl. But in principle, it could be configurable (like you are doing with the optimizer/cost model). Extending this idea, I think an interesting thing to implement would be hooking a more aggressive type-stabilization mechanism like the one used by Catwalk into the JuliaFolds executor API. It could then be composed with JuliaFolds-based programs (including parallel ones) and would let you write more natural for-loop syntax via the @floop macro, without adding anything extra to the user-side code.

At the end of “How it works”, you mention that the generator looks up a global dict. I wonder if you could instead attach the state to the loop state (“accumulator”) of the loop you are optimizing by transforming the loop body (which is actually what transducer + foldl is all about). I imagine it’d make the stats updates race-free (“thread safe”) and get rid of the “fishy” part of the @generated usage.
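A minimal sketch of that suggestion, assuming the statistics are a plain Dict carried inside a NamedTuple accumulator of foldl. The function names are hypothetical, and this is not how Catwalk.jl or Transducers.jl implement it. Because each fold starts from its own fresh accumulator, no global state is touched; in a parallel reduction, per-task partial results would be combined, as merge_stats does here.

```julia
# Fold over the values, carrying both the running sum and per-type counts.
function sum_with_type_stats(vals)
    init = (sum = zero(ComplexF64), counts = Dict{DataType,Int}())  # fresh per call
    return foldl(vals; init = init) do acc, x
        T = typeof(x)
        acc.counts[T] = get(acc.counts, T, 0) + 1   # stats live in the accumulator
        (sum = acc.sum + x, counts = acc.counts)
    end
end

# Combine two partial results (e.g. from two tasks) without shared state.
merge_stats(a, b) = (sum = a.sum + b.sum, counts = mergewith(+, a.counts, b.counts))
```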

I have just started digesting your work, so all I can say for now is that I would be very glad if my work inspired you a bit! Both Transducers.jl and FLoops.jl have long been on my list to check out closely, but my interest is not in data parallelism, so unfortunately most of your work does not seem directly applicable in my case. I still need to learn a lot, though, and reading your code will definitely improve my craft.

I am not sure I fully understand you; there may be a misunderstanding. The stats collection uses a “local” dict (part of the context in profiled batches); the global dict is only used for “exploration”, which is more or less a convenience feature: it connects the macro-marked call sites to the JIT compiler, which runs in the outer, “batch” loop. Exploration can be turned off completely if the jit-ed call sites are listed at the initialization of the JIT compiler. It sounds a bit strange (why create an auto-configuration mechanism for that?), and the answer is that it was developed accidentally: I thought I could do it easily, without globals and fully at compile time, but it turned out to be seemingly impossible. Maybe I should remove it and force the user to always configure the compiler. I will see; I think it is too early to decide.

Please also note that the loop itself is not necessarily visible to the macro: in the framework I wrote Catwalk for, there is already an abstraction layer (a function call) between the inner loop and the jit-ed call, and there is also a composition of two jit-ed functions (something like in this test: https://github.com/tisztamo/Catwalk.jl/blob/0ed6d8822ddf7aa4935a062cfae753fabfc9b036/test/scheduling.jl#L28). I am currently working on adding another layer of abstraction there with Plugins.jl, which uses generics to compile code from different parts of the application (“plugins”) together into the hot loop, so the macro will have no way to know where the loop is.

Edit: added a clarification to the end of the OP.

Good to know that they are on your radar! :slight_smile: But let me emphasize that they both started as an alternative conceptualization of sequential loops. You can probably just look at the sequential loop _foldl_iter implementation (reading @next(rf, val, x) as simply rf(val, x)) to see how it is “JIT”-ing [1] the loop. For each recursion, I can insert if val isa T1 ... elseif val isa T2 ..., as explained in your documentation, to make it more powerful.

(The reason why I moved to emphasize data-parallelism is that this approach is powerful enough also to capture the transformations required for parallelism.)

Thanks for the clarification. Now that I look at the documentation again, it’s clearly in the “subsection” explaining the exploration feature.

Don’t you still need to have a “batched” inner loop inside of the function that lifts the frequency information to the type domain? That is to say, the user needs to “batch” the loop.

In contrast, I think Transducers.jl has enough scaffolding to remove the manual/intrusive batching from the user code. I think it’s possible to provide an API like

const NUM_ALL = NUM_BATCHES * NUM_ITERS_PER_BATCH  # i.e., no user-side "batching"

executor = JITEx(batchsize = NUM_ITERS_PER_BATCH)
Folds.foreach(calc_with_x, Iterators.map(get_some_x, 1:NUM_ALL), executor)

although, of course, JITEx does not exist yet. I think it’d be nice to have an interface in which batching is completely decoupled from the actual algorithms/science/“business logic.” The foldl function provides the required abstraction.
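To illustrate the kind of decoupling being discussed, here is a hypothetical helper (not part of Folds.jl, Transducers.jl or Catwalk.jl) that splits any iterator into batches with Iterators.partition and calls a per-batch hook, which is where a JIT's housekeeping step could go. The user-supplied body never sees the batching.

```julia
# Hypothetical: run f over itr in batches of `batchsize`, calling the
# `between_batches` hook after each batch (e.g. an optimizer step).
function foreach_batched(f, itr, batchsize::Integer; between_batches = () -> nothing)
    for batch in Iterators.partition(itr, batchsize)
        foreach(f, batch)
        between_batches()
    end
    return nothing
end
```

For example, foreach_batched(x -> total[] += x, 1:10, 3; between_batches = () -> nbatches[] += 1) with two Refs sums to 55 over 4 batches (1:3, 4:6, 7:9, 10:10).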


  1. In my documentation and code, I usually call it type-stabilization or something, though. It’s too primitive compared to what I think Catwalk.jl is doing. Also, my foldl is stone-age tech compared to, say, V8 or the JVM.

I am starting to understand your function-barrier-based accumulation method, and I like its elegance, especially that it works without @generated functions. However (I am not sure I got it right), it seems that the Val-typed stack-limiting counter in __foldl_iter prevents reusing previously compiled code when an old type comes back. Is that true?
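For readers following the exchange, here is a simplified sketch of function-barrier-based accumulation with a Val-typed depth limit. It is written from the description in this thread, not copied from Transducers.jl, and the names foldl_barrier, _foldl_loop and _foldl_untyped are hypothetical. It also illustrates the question above: each (accumulator type, remaining depth) pair is a separate specialization of _foldl_loop, so a type that comes back at a different depth is compiled again.

```julia
function foldl_barrier(rf, init, itr; maxdepth::Int = 8)
    y = iterate(itr)
    y === nothing && return init
    x, state = y
    return _foldl_loop(rf, rf(init, x), itr, state, Val(maxdepth))
end

# Specialized on the accumulator type T and the remaining depth N.
function _foldl_loop(rf, acc::T, itr, state, ::Val{N}) where {T,N}
    while true
        y = iterate(itr, state)
        y === nothing && return acc
        x, state = y
        acc2 = rf(acc, x)
        if acc2 isa T
            acc = acc2    # type unchanged: stay in this specialized loop
        elseif N > 0
            # type changed: re-enter through a function barrier
            return _foldl_loop(rf, acc2, itr, state, Val(N - 1))
        else
            return _foldl_untyped(rf, acc2, itr, state)  # depth exhausted
        end
    end
end

# Fallback loop without further respecialization.
function _foldl_untyped(rf, acc, itr, state)
    while (y = iterate(itr, state)) !== nothing
        x, state = y
        acc = rf(acc, x)
    end
    return acc
end
```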

That would be great to see! In the context of transducers, this looks very natural and nice.

It was a hard decision regarding the API of Catwalk, I had something like this in my mind, but I ended up deciding to not require the user to split the problematic part of the code into two functions, and hacking the body of the jit-ed function instead. In retrospect I am almost sure this wasn’t the right decision.

Yes, you are absolutely right! That’s why I said it was a weaker version. On the other hand, the strategy I used in FoldsThreads.trampoline_stabilizing tracks the whole history of types and does manual union splitting, and I think it’s somewhat closer to Catwalk.jl. But it’s still kind of dumb, since there is nothing clever like frequency analysis. I use the weak version by default since it’s easy on the compiler and will be compiled away if the compiler can prove there is no type instability (which makes supporting GPUs easier).

So, I tried hooking Catwalk into JuliaFolds to make sure my guess about the API was correct. The integration actually wasn’t so hard. I tried a couple of examples but I’m still not able to demonstrate the performance property of Catwalk. I guess I need a bit more analysis.

Nice!

I was able to fix the performance issue (Only tested with the sum example):

https://github.com/JuliaFolds/FoldsCatwalk.jl/pull/1

Great! Thanks a lot for looking into this.

Suppose f and g are pure functions, and I want to compute h(x) = f(x) && g(x). && is short-circuiting, and because f and g are pure, the only difference between h and h'(x) = g(x) && f(x) is performance, and that whether h or h' is faster at a call site (or every call site) can depend on the workload at runtime.

Could Catwalk address this kind of adaptive optimization?
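The trade-off can be written down exactly when the costs are constant: in f(x) && g(x), f always runs and g runs only when f returned true. With c_f, c_g the costs and p_f, p_g the probabilities of returning true, f-first is cheaper on average iff c_f + p_f c_g <= c_g + p_g c_f (no independence assumption is needed, since each side uses only one marginal probability). The helper names below are hypothetical and the estimates are assumed to come from elsewhere.

```julia
# Expected cost of `a(x) && b(x)`: a always runs, b only when a was true.
expected_and_cost(c_a, p_a, c_b) = c_a + p_a * c_b

# Hypothetical decision rule for pure f and g: evaluate f first iff
# c_f + p_f * c_g <= c_g + p_g * c_f.
prefer_f_first(c_f, p_f, c_g, p_g) =
    expected_and_cost(c_f, p_f, c_g) <= expected_and_cost(c_g, p_g, c_f)

# Build h or h' depending on the (externally estimated) costs and pass rates.
make_and(f, g, c_f, p_f, c_g, p_g) =
    prefer_f_first(c_f, p_f, c_g, p_g) ? (x -> f(x) && g(x)) : (x -> g(x) && f(x))
```

Estimating p_g is the tricky part: with a fixed order, g only sees inputs that passed f, so profiling has to evaluate both sides on the same inputs.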

Interesting idea!

It seems possible to do that within the current concept without huge changes to the code, but short-circuit profiling has very different cost characteristics than dispatch profiling: both sides must be evaluated and timed, although maybe a few profiled executions would be enough to decide. Currently, profiled batches are selected automatically and thus typically have the same length as non-profiled ones, which may not work well here.

Can you please share your use case? What are typical execution times of f and g, and how fast can they change during runtime?

Sorry for the late follow-up. I don’t have specific examples with execution times, but roughly what I had in mind is checking whether a point is in a geometric set: in(p, S). To speed up that query, S may have a simple over-approximation (bounding volume) and a simple under-approximation, and the query can be checked against those to return true or false without checking the shape itself. A sketch of the logic is:

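function in_fast(p, S)
    # Inside the under-approximation implies inside S
    if in(p, underapproximate(S)); return true; end
    # Outside the over-approximation implies outside S
    if !in(p, overapproximate(S)); return false; end
    # Otherwise, actually check against S itself (the expensive test)
    return in(p, S)
end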

which could be written in one expression in terms of && and ||.
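Here is the one-expression form, with a small set of hypothetical shape types so it can actually run: Disk, Box, underapproximate and overapproximate are illustrative assumptions. Correctness relies only on underapproximate(S) ⊆ S ⊆ overapproximate(S).

```julia
# Hypothetical shapes: a disk of radius r and a square [-h, h]^2, both centered at the origin
struct Disk
    r::Float64
end
struct Box
    h::Float64
end

Base.in(p::NTuple{2,Float64}, b::Box) = abs(p[1]) <= b.h && abs(p[2]) <= b.h
Base.in(p::NTuple{2,Float64}, d::Disk) = p[1]^2 + p[2]^2 <= d.r^2   # the "expensive" exact test

underapproximate(d::Disk) = Box(d.r / sqrt(2))   # inscribed square: inside it => inside the disk
overapproximate(d::Disk) = Box(d.r)              # bounding square: outside it => outside the disk

# Accept early, reject early, otherwise fall back to the exact check
in_fast(p, S) = in(p, underapproximate(S)) || (in(p, overapproximate(S)) && in(p, S))
```

Which of the two cheap checks should come first is exactly the workload-dependent ordering question raised earlier in the thread.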

There are many situations with the pattern of checking easy or common cases before doing the complicated, expensive check.

(Other situations that come to mind along these lines are 1. eager vs. lazy computation, and 2. caching a result vs. recomputing it as needed. I don’t know how to express the possibilities succinctly in these cases, though.)
