I'm not sure I understand all the subtleties of your task, but here are some design decisions from Yota, which seems to be somewhat similar.
First of all, the overhead of `Base.invokelatest()` might not be a big problem unless you work with very trivial operations. As an example, consider a fairly simple in-place multiplication of two 100x100 matrices `X` and `Y`:
```julia
julia> @btime mul!(Z, X, Y);
27.355 μs (0 allocations: 0 bytes)
julia> @btime Base.invokelatest(mul!, Z, X, Y);
27.349 μs (1 allocation: 32 bytes)
```
`Base.invokelatest` added so little overhead that the benchmark couldn't even detect the difference. It's also usually much cheaper than allocating memory for buffers or interpreting graphs. You can still use hacks to avoid dynamic dispatch altogether, e.g. by compiling your function before defining another one that uses it, but I really recommend benchmarking `invokelatest` thoroughly in your setting before making more complex decisions.
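To make the world-age issue concrete, here is a minimal sketch (the function `call_new_function` and the generated expression are made up for illustration). A function that `eval`s a new function at runtime cannot call it directly, because the caller keeps running in the world age in which it started; `Base.invokelatest` looks the method up in the latest world instead.

```julia
# Hypothetical example: `eval` creates a new function while `call_new_function`
# is already running, i.e. the new method lives in a newer world age.
function call_new_function(x; use_invokelatest::Bool = true)
    f = eval(:(y -> 2y + 1))
    if use_invokelatest
        # Dispatches in the latest world, so the freshly defined method is found.
        return Base.invokelatest(f, x)
    else
        # Direct call from the caller's older world age: the method of `f`
        # is "too new" here, so this throws a MethodError.
        return f(x)
    end
end

call_new_function(3)  # returns 7

try
    call_new_function(3; use_invokelatest = false)
catch err
    println(typeof(err))  # MethodError
end
```

This is the situation `invokelatest` exists for; the benchmark above suggests its extra dynamic dispatch is usually negligible next to real work.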
Second, from your description I understand you want to compile a new function and then differentiate it using Zygote. In Yota I use the opposite approach:
- `Yota.grad(f, args...)` traces the execution of function `f`, creating an unrolled tape similar to what you described.
- Differentiation is done on this tape, and gradient operations are recorded onto it. You end up with a single tape doing both the forward (evaluation) and reverse (differentiation) passes.
- The tape is compiled and cached. `Base.invokelatest()` is used to call the compiled function even if it's defined in a later world than the current one. Since the compiled tape already contains the instructions for calculating gradients, there's no need to define adjoints for `Base.invokelatest()`.
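As a rough illustration of the last two steps, here is a toy sketch; it is not Yota's API or internal representation. The `Op` struct, `TAPE`, `tape_to_expr`, `COMPILED` and `grad_via_tape` are hypothetical names, and the reverse-pass entries are written by hand rather than derived from the forward ones. It shows a single tape holding both passes being turned into one function with `eval`, cached, and called through `Base.invokelatest`.

```julia
# Hypothetical toy tape entry meaning `out = fn(args...)`.
# NOT Yota's data structure, only an illustration of the idea.
struct Op
    out::Symbol
    fn::Symbol
    args::Vector{Symbol}
end

# Tape for f(x) = sin(x) * x. The reverse-pass entries are written by hand
# here; a real AD system would derive them from the forward entries.
const TAPE = [
    Op(:a, :sin, [:x]),     # forward: a = sin(x)
    Op(:y, :*, [:a, :x]),   # forward: y = a * x
    Op(:ca, :cos, [:x]),    # reverse: da/dx = cos(x)
    Op(:t, :*, [:ca, :x]),  # reverse: t = cos(x) * x
    Op(:dx, :+, [:t, :a]),  # reverse: dy/dx = cos(x) * x + sin(x)
]

# Turn the whole tape into one anonymous-function expression.
# For this toy, the outputs are fixed to `y` and `dx`.
function tape_to_expr(tape)
    body = [:($(op.out) = $(op.fn)($(op.args...))) for op in tape]
    return :(x -> begin
        $(body...)
        return (y, dx)
    end)
end

# Compiled tapes cached by tape identity, so eval + compilation happen once.
const COMPILED = IdDict{Any,Any}()

function grad_via_tape(tape, x)
    f = get!(() -> eval(tape_to_expr(tape)), COMPILED, tape)
    # `f` may be defined in a newer world age than the one this call
    # started in, so it is invoked through `invokelatest`.
    return Base.invokelatest(f, x)
end

y, dy = grad_via_tape(TAPE, 2.0)  # value and derivative of sin(x) * x at 2.0
```

Only the first call with a given tape pays for `eval` and compilation; later calls reuse the cached function, and since the gradient code is already part of it, nothing has to differentiate through `invokelatest`.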
A few words about performance. Yota uses Cassette (or JuliaInterpreter) to trace the function call, which is pretty slow, so the first call may take anywhere from a few seconds to a few minutes. I haven't followed Zygote recently, but previous versions used a similar approach and may be subject to the same issue. Subsequent executions of the compiled code, however, are much faster (usually as fast as hand-written code).
I also recommend optimizing the tape/graph before compilation, e.g. eliminating common subexpressions, removing unused nodes (e.g. those produced during tracing but not needed in the unrolled graph), using pre-allocated buffers, and using in-place operations where possible. In my experience, the difference between an optimized and an unoptimized graph can be as large as 100x.
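Two of these optimizations, common subexpression elimination and removal of unused nodes, can be sketched on a toy tape. The `Op` struct and the `cse`/`dce` functions are hypothetical, and both passes assume every operation is pure; a tape containing in-place operations such as `mul!` would need to track side effects before merging or dropping anything.

```julia
# Hypothetical toy tape entry meaning `out = fn(args...)`.
struct Op
    out::Symbol
    fn::Symbol
    args::Vector{Symbol}
end

# Common subexpression elimination: if the same (fn, args) was already
# recorded, reuse its output instead of recomputing it. Assumes pure ops.
function cse(tape::Vector{Op})
    seen = Dict{Tuple{Symbol,Vector{Symbol}},Symbol}()  # (fn, args) => first output
    rename = Dict{Symbol,Symbol}()                      # dropped output => kept output
    result = Op[]
    for op in tape
        args = [get(rename, a, a) for a in op.args]
        key = (op.fn, args)
        if haskey(seen, key)
            rename[op.out] = seen[key]
        else
            seen[key] = op.out
            push!(result, Op(op.out, op.fn, args))
        end
    end
    return result, rename
end

# Dead code elimination: walk the tape backwards and keep only ops whose
# outputs are (transitively) needed for `outputs`. Also assumes pure ops.
function dce(tape::Vector{Op}, outputs::Vector{Symbol})
    live = Set{Symbol}(outputs)
    kept = Op[]
    for op in reverse(tape)
        if op.out in live
            push!(kept, op)
            union!(live, op.args)
        end
    end
    return reverse(kept)
end

tape = [
    Op(:a, :sin, [:x]),
    Op(:b, :sin, [:x]),       # duplicate of :a, merged by CSE
    Op(:c, :*, [:a, :b]),
    Op(:unused, :cos, [:x]),  # not needed for :c, dropped by DCE
]

deduped, rename = cse(tape)
optimized = dce(deduped, [get(rename, :c, :c)])  # [a = sin(x), c = a * a]
```

Running CSE first matters here: merging `:b` into `:a` is what lets the final tape compute `sin(x)` only once.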
Hope this is useful for your package.