Would it be possible to implement JAX-style reverse-mode AD in Julia?

I’ve spent a fair bit of time moving between JAX and Julia over the last year. One thing that stands out to me is that the AD libraries in Julia work in a fundamentally different way from JAX. As far as I can tell, the Julia AD libraries are strongly tied to the reverse-mode interpretation based on cotangent vectors, and ChainRules.jl provides explicit reverse-derivative rules for each map. JAX takes a different approach: it decomposes reverse-mode AD into two steps, first defining a forward-mode jvp operation and then an abstract transpose operation that turns the jvp into a vjp (I will admit I have my own reasons for thinking this is the more desirable approach).
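For readers less familiar with the decomposition: if J is the Jacobian of f at x, jvp computes v ↦ J v and vjp computes u ↦ Jᵀ u, so the vjp is the transpose (adjoint) of the jvp's linear map, characterized by ⟨u, J v⟩ = ⟨Jᵀ u, v⟩. Below is a minimal plain-Python sketch of that identity, assuming a toy function `f` and a hand-written `jvp` (both hypothetical, not taken from JAX or any Julia package). It builds J naively from jvp calls on basis vectors; JAX instead transposes the linear program directly without materializing J, but the resulting map is the same.

```python
import math

def f(x):
    # Toy function f: R^2 -> R^2 (illustrative only)
    x0, x1 = x
    return [x0 * x1, math.sin(x0)]

def jvp(x, v):
    # Hand-written forward-mode derivative: returns J(x) @ v, linear in v
    x0, x1 = x
    v0, v1 = v
    return [x1 * v0 + x0 * v1, math.cos(x0) * v0]

def vjp_by_transposition(x, u):
    # Naive transpose: materialize J column by column via jvp on basis vectors,
    # then compute (J^T u)_j = dot(column j, u)
    n = len(x)
    cols = [jvp(x, [1.0 if i == j else 0.0 for i in range(n)]) for j in range(n)]
    return [sum(c_i * u_i for c_i, u_i in zip(col, u)) for col in cols]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

x = [0.5, 2.0]
v = [0.3, -1.2]   # arbitrary tangent
u = [1.5, 0.7]    # arbitrary cotangent
lhs = dot(u, jvp(x, v))
rhs = dot(vjp_by_transposition(x, u), v)
print(abs(lhs - rhs) < 1e-12)  # True: <u, J v> == <J^T u, v>
```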

I can’t see any reason why ForwardDiff.jl couldn’t support JAX’s forward-mode API: it’s easy enough to define a jvp operation externally (although I haven’t found a way to extend it to a Jacobian-matrix product without incurring a lot of allocations). The transpose operations also seem easier to implement and optimize than hand-written reverse rules, though that may be XLA magic on JAX’s side. This approach also seems to make nested differentiation easier. I would also imagine JAX has a much smaller set of primitive operations, which may be the big issue.
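On the Jacobian-matrix product point: forward mode can push k tangent directions through a single evaluation by carrying k partials per value, which is roughly the idea behind the chunked Dual numbers in ForwardDiff.jl. The plain-Python sketch below only illustrates that idea; `Dual`, `jmp` and the example `f` are hypothetical names, not ForwardDiff's implementation or API, and it makes no attempt to address the allocation concern above.

```python
import math

class Dual:
    """A primal value plus k partials, one per tangent direction (column of V)."""
    def __init__(self, val, partials):
        self.val = val
        self.partials = tuple(partials)

    def __add__(self, other):
        other = lift(other, len(self.partials))
        return Dual(self.val + other.val,
                    [a + b for a, b in zip(self.partials, other.partials)])
    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other, len(self.partials))
        # Product rule, applied to all k directions in one pass
        return Dual(self.val * other.val,
                    [self.val * b + other.val * a
                     for a, b in zip(self.partials, other.partials)])
    __rmul__ = __mul__

def lift(c, k):
    # A plain constant has zero partials in every direction
    return c if isinstance(c, Dual) else Dual(c, [0.0] * k)

def sin(d):
    return Dual(math.sin(d.val), [math.cos(d.val) * p for p in d.partials])

def jmp(f, x, V):
    """Return (f(x), J @ V); V[i][j] is component i of tangent direction j."""
    outs = f([Dual(xi, V[i]) for i, xi in enumerate(x)])
    return [o.val for o in outs], [list(o.partials) for o in outs]

def f(x):
    x0, x1 = x
    return [x0 * x1, sin(x0) + 3.0 * x1]

# With V = identity, J @ V is the full 2x2 Jacobian from a single evaluation
y, JV = jmp(f, [0.5, 2.0], [[1.0, 0.0], [0.0, 1.0]])
print(JV)  # [[x1, x0], [cos(x0), 3.0]] = [[2.0, 0.5], [cos(0.5), 3.0]]
```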

Note: I don’t mean this as a criticism, I think Zygote, ReverseDiff, Diffractor are all great projects. I’m just trying to sort out if anyone knows a fundamental reason why Jax’s AD model can’t be ported over to Julia.

It totally would be possible, just a lot of work. Much of the difference between the Julia and JAX approaches to AD is that the Julia ones have tended to be designed to be implementable by a few devs part time, while JAX’s approach requires a well-funded team.

Does that extend to the AD system itself? Per the paper I linked, part of the benefit of decomposing AD into a forward mode pass and a transposition is to reduce the amount of developer work:

It is tempting to try to quantify the cost savings due to a smaller rule set. In JAX the number of transposition rules is roughly 40% of the JVP rule count.

So running a partially evaluated jvp(f, x, -) through something like Cassette.jl seems like it would get you some fraction of the way there.
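A rough sketch of what partially evaluating jvp(f, x, -) means: everything that depends only on the primal point x (here sin, cos and products of primals) is computed eagerly, and only the tangent computation is staged, leaving a small linear program of scale/add operations that can be replayed on any tangent. This plain-Python version uses operator overloading rather than a Cassette-style code transform, and all names (`Linearizer`, `PDual`, `linearize`, `eval_linear`) are hypothetical, loosely modeled on what JAX's `jax.linearize` produces.

```python
import math

class Linearizer:
    """Holds the staged linear program (the 'tape') for the tangent computation."""
    def __init__(self):
        self.tape = []  # entries: (out_var, op, in_vars, constant)
        self.count = 0

    def emit(self, op, ins, const=None):
        self.count += 1
        out = f"v{self.count}"
        self.tape.append((out, op, ins, const))
        return out

class PDual:
    """A known primal value paired with a staged (symbolic) tangent variable."""
    def __init__(self, lin, val, tan):
        self.lin, self.val, self.tan = lin, val, tan

    def __add__(self, other):
        return PDual(self.lin, self.val + other.val,
                     self.lin.emit("add", (self.tan, other.tan)))

    def __mul__(self, other):
        # d(a*b) = b*da + a*db: a and b are known now, so only scale/add are staged
        b_da = self.lin.emit("scale", (self.tan,), other.val)
        a_db = self.lin.emit("scale", (other.tan,), self.val)
        return PDual(self.lin, self.val * other.val,
                     self.lin.emit("add", (b_da, a_db)))

def sin(x):
    # d(sin a) = cos(a) * da: cos(a) is computed eagerly, only the scaling is staged
    return PDual(x.lin, math.sin(x.val),
                 x.lin.emit("scale", (x.tan,), math.cos(x.val)))

def linearize(f, *xs):
    """Return f(*xs) and a linear program computing jvp(f, xs, t) for any tangent t."""
    lin = Linearizer()
    in_vars = [f"t{i}" for i in range(len(xs))]
    y = f(*[PDual(lin, x, t) for x, t in zip(xs, in_vars)])
    return y.val, in_vars, y.tan, lin.tape

def eval_linear(tape, in_vars, out_var, tangents):
    env = dict(zip(in_vars, tangents))
    for out, op, args, const in tape:
        if op == "scale":
            env[out] = const * env[args[0]]
        elif op == "add":
            env[out] = env[args[0]] + env[args[1]]
        else:
            raise ValueError(f"not a linear op: {op}")
    return env[out_var]

def f(x0, x1):
    return sin(x0) * x1

y, ins, out, tape = linearize(f, 0.5, 2.0)
# The staged program is replayed on new tangents without recomputing sin or cos
print(eval_linear(tape, ins, out, [1.0, 0.0]))  # df/dx0 = 2.0 * cos(0.5)
print(eval_linear(tape, ins, out, [0.0, 1.0]))  # df/dx1 = sin(0.5)
```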

JAX’s AD relies on a source code transform, so I’m not sure how you’d replicate that with an overloading-based AD like ForwardDiff. That’s not to say an overloading-based approach can’t get there, of course: Yota.jl (dfdx/Yota.jl on GitHub, reverse-mode automatic differentiation in Julia) does a similar trace-to-tape method for reverse mode, and CodeShifter.jl (pabloferz/CodeShifter.jl on GitHub) is an experiment in porting the transposition transform to Julia (among other goodies).
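To illustrate how overloading can still yield a transformable representation, here is a minimal trace-to-tape sketch in the spirit of make_jaxpr: tracer objects record each primitive call into a flat tape that can later be replayed or transformed. It is plain Python with hypothetical names (`Tape`, `Tracer`, `trace`, `run`), supports only `add`, `mul` and `sin` on traced values, and is not how Yota.jl or Umlaut.jl are actually implemented.

```python
import math

class Tape:
    def __init__(self):
        self.ops = []  # entries: (out_var, primitive, in_vars)
        self.n = 0

    def record(self, prim, ins):
        self.n += 1
        out = f"%{self.n}"
        self.ops.append((out, prim, ins))
        return out

class Tracer:
    """Stands in for a value; each primitive applied to it is appended to the tape."""
    def __init__(self, tape, var):
        self.tape, self.var = tape, var

    def __add__(self, other):
        return Tracer(self.tape, self.tape.record("add", (self.var, other.var)))

    def __mul__(self, other):
        return Tracer(self.tape, self.tape.record("mul", (self.var, other.var)))

def sin(x):
    return Tracer(x.tape, x.tape.record("sin", (x.var,)))

def trace(f, nargs):
    """Run f once on tracers and return (input vars, flat list of ops, output var)."""
    tape = Tape()
    in_vars = [f"x{i}" for i in range(nargs)]
    out = f(*[Tracer(tape, v) for v in in_vars])
    return in_vars, tape.ops, out.var

# Interpreter for the closed primitive set the tape may contain
PRIMITIVES = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b, "sin": math.sin}

def run(in_vars, ops, out, values):
    env = dict(zip(in_vars, values))
    for o, prim, args in ops:
        env[o] = PRIMITIVES[prim](*(env[a] for a in args))
    return env[out]

def f(x0, x1):
    return sin(x0) * x1 + x1

ins, ops, out = trace(f, 2)
for o, prim, args in ops:
    print(o, "=", prim, *args)
# %1 = sin x0
# %2 = mul %1 x1
# %3 = add %2 x1
print(run(ins, ops, out, [0.5, 2.0]))  # sin(0.5) * 2.0 + 2.0
```

A Python-level branch on a traced value cannot be recorded here (a `Tracer` has no concrete value to compare), so this kind of tracing only captures straight-line code; JAX handles this with structured control-flow primitives such as `lax.cond`.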

That said, the tricky part is not implementing what’s written down on paper but making it usable for real-world programs. The tradeoff JAX makes here is to assume a closed universe of operations and complete immutability of values, then plug everything into a giant array/linear-algebra optimizer and hope a fast program comes out the other side. As you can imagine, the Julia community doesn’t have Google-level resources to do something similar, and even if it did, there’s the question of whether the immutable + vectorize-everything model makes sense for this language. After all, part of Julia’s value proposition is that you can write fast loops, write code with predictable performance, and extend existing libraries as you wish.

Yes, it’s certainly possible to implement JAX-like AD in Julia, but the scope may be different from the one described. Let’s break it down.

  1. Getting a linear representation of a function call, i.e. what you get in JAX using make_jaxpr, can already be done using Umlaut.jl. Note that this representation implies you only have a closed set of primitives (aka base operations) in the jaxpr / tape. It’s a great representation since you can stack transformations on it, e.g. apply vjp on top of jvp on top of vmap, and get a valid representation at each step. JAX has ~200 such primitives.
  2. You can find implementations of JVPs / VJPs in ChainRules’s frule / rrule, but these implementations use tools beyond the closed set of primitives. Or, if you wish, ChainRules’s set of primitives is the whole Julia language, including structs, closures, mutation, the type system, control flow, etc. Writing transformations that can handle all of these is no fun. Zygote, Diffractor and Enzyme do a great job of bringing AD to that low level, but it requires a lot of work from very smart people. If you don’t want to deal with all these details and are fine with a closed set of primitives, you’ll have to adapt ChainRules to that set.
  3. I’ve never touched transposition or generating VJPs from JVPs, since I’ve never been interested in forward-mode AD, but from the context of the paper it doesn’t seem to be a hard problem. Again, assuming a limited set of primitives.
  4. Since XLA came up in this thread: I don’t know an easy way to use it from Julia. XLA has almost no documentation, no official external API, and doesn’t seem very Julia-friendly. However, there’s ONNX, which is similar, and we’re already working on it. Targeting a format like ONNX is great not only because you can get an optimized computational graph, but also for exchanging pretrained models. Yet ONNX contains ~250 operators, so again, not every Julia program can be transformed into it.
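Point 1 is the key property: if every primitive in a closed set has a JVP rule whose output again uses only primitives from that set, then jvp maps a tape to another valid tape, and transformations can be stacked. Here is a toy plain-Python sketch, assuming a five-primitive set and a tape format of (output, primitive, inputs) tuples (both hypothetical, chosen only for illustration, not Umlaut.jl's or JAX's actual formats), that applies jvp twice to get a second derivative.

```python
import itertools
import math

# Closed primitive set: each primitive has an evaluator and (below) a JVP rule.
EVAL = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "neg": lambda a: -a,
    "sin": math.sin,
    "cos": math.cos,
}

_ids = itertools.count()

def fresh():
    # Globally unique variable names, so transformed tapes can be transformed again
    return f"%{next(_ids)}"

def jvp_rule(prim, ins, tins):
    """Ops computing the tangent of prim(*ins), using only primitives from EVAL."""
    if prim == "add":
        t = fresh()
        return [(t, "add", (tins[0], tins[1]))], t
    if prim == "mul":  # d(a*b) = da*b + a*db
        p, q, t = fresh(), fresh(), fresh()
        return [(p, "mul", (tins[0], ins[1])), (q, "mul", (ins[0], tins[1])),
                (t, "add", (p, q))], t
    if prim == "neg":
        t = fresh()
        return [(t, "neg", (tins[0],))], t
    if prim == "sin":  # d(sin a) = cos(a) * da
        c, t = fresh(), fresh()
        return [(c, "cos", (ins[0],)), (t, "mul", (c, tins[0]))], t
    if prim == "cos":  # d(cos a) = -sin(a) * da
        s, n, t = fresh(), fresh(), fresh()
        return [(s, "sin", (ins[0],)), (n, "neg", (s,)), (t, "mul", (n, tins[0]))], t
    raise NotImplementedError(f"no JVP rule for primitive {prim!r}")

def jvp_transform(ins, ops, out):
    """Tape -> tape: new inputs are (primals, tangents); returns the tangent output var."""
    tan = {x: fresh() for x in ins}
    new_ops = list(ops)  # keep the primal computation
    for o, prim, args in ops:
        rule_ops, t = jvp_rule(prim, args, [tan[a] for a in args])
        new_ops += rule_ops
        tan[o] = t
    return ins + [tan[x] for x in ins], new_ops, tan[out]

def run(ins, ops, values):
    env = dict(zip(ins, values))
    for o, prim, args in ops:
        env[o] = EVAL[prim](*(env[a] for a in args))
    return env

# f(x, y) = sin(x) * y as a tape of (output, primitive, inputs)
ins, ops, out = ["x", "y"], [("s", "sin", ("x",)), ("z", "mul", ("s", "y"))], "z"

ins1, ops1, t1 = jvp_transform(ins, ops, out)
env = run(ins1, ops1, [0.5, 2.0, 1.0, 0.0])   # tangent direction (1, 0)
print(env[out], env[t1])                       # sin(0.5)*2, df/dx = cos(0.5)*2

# The result is again a tape over EVAL's primitives, so jvp can be applied again.
ins2, ops2, t2 = jvp_transform(ins1, ops1, t1)
env2 = run(ins2, ops2, [0.5, 2.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
print(env2[t2])                                # d2f/dx2 = -sin(0.5)*2
```

Anything outside the closed set hits the `NotImplementedError`, which is the tradeoff discussed in this thread: each transformation stays simple because the universe of operations is small.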

So, if you don’t want to deal with all the peculiarities of the Julia language and their mapping to external systems, you have to define a set of primitives. Recently I started Remix.jl, an experiment to create such a set. It’s at a very early stage, but contains a working prototype of reverse-mode AD with a JAX-like API. I’m not sure about the future of this experiment, but feel free to play around with it.

As @ToucheSir mentioned, it should be possible to build such an AD system in Julia, and I have long thought about implementing AD exactly as a composition of either forward or reverse mode with transposition. CodeShifter.jl is my own experiment in trying to achieve this in Julia. Unfortunately, I haven’t had the time to make as much progress as I’d like.

So far I have the equivalent of make_jaxpr, called code_shifted in CodeShifter.jl, which gives a linear representation of a function call as @dfdx mentioned, and a barebones JVP implementation with its own set of rules and an interface similar to jvp in JAX. From what I can tell so far, implementing transpose rules will indeed be less work than implementing reverse rules for whatever set of primitives I end up supporting.
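A sketch of why transpose rules tend to be less work: after linearization, the nonlinear parts such as cos(x0) are already plain constants, so transposition needs only one rule per linear primitive (here `scale`, `add` and `neg`), plus accumulation of cotangents wherever a variable fans out. The linear tape below is a hypothetical hand-written linearization of f(x0, x1) = sin(x0) * x1 at (0.5, 2.0); running it backwards with those three rules gives the vjp without any reverse rule for sin or mul. This is plain Python with hypothetical names (`TRANSPOSE`, `transpose`), not CodeShifter.jl's or JAX's implementation.

```python
import math
from collections import defaultdict

# Linear tape over tangents: only 'scale' (by a known constant), 'add' and 'neg'.
x0, x1 = 0.5, 2.0
LIN_INS = ["t0", "t1"]
LIN_OPS = [
    ("a", "scale", ("t0",), math.cos(x0)),  # a = cos(x0) * t0
    ("b", "scale", ("a",), x1),             # b = x1 * a
    ("d", "scale", ("t1",), math.sin(x0)),  # d = sin(x0) * t1
    ("out", "add", ("b", "d"), None),       # out = b + d
]
LIN_OUT = "out"

def eval_linear(ops, env):
    env = dict(env)
    for o, op, args, c in ops:
        if op == "scale":
            env[o] = c * env[args[0]]
        elif op == "add":
            env[o] = env[args[0]] + env[args[1]]
        elif op == "neg":
            env[o] = -env[args[0]]
        else:
            raise ValueError(f"not a linear op: {op}")
    return env

# One transpose rule per linear primitive: given the output cotangent `ct`
# (and the op's constant `c`), return the cotangent contribution for each input.
TRANSPOSE = {
    "scale": lambda ct, c: [c * ct],
    "add":   lambda ct, c: [ct, ct],
    "neg":   lambda ct, c: [-ct],
}

def transpose(ops, ins, out, ct_out):
    """Run the linear tape backwards, accumulating cotangents: this is the vjp."""
    ct = defaultdict(float)
    ct[out] = ct_out
    for o, op, args, c in reversed(ops):
        for a, contrib in zip(args, TRANSPOSE[op](ct[o], c)):
            ct[a] += contrib  # a variable used several times gets summed contributions
    return [ct[x] for x in ins]

grad = transpose(LIN_OPS, LIN_INS, LIN_OUT, 1.0)
print(grad)  # [x1 * cos(x0), sin(x0)], the gradient of f at (0.5, 2.0)

# Adjoint check: <u, J v> == <J^T u, v> for an arbitrary tangent v and cotangent u
v, u = [0.3, -1.2], 1.7
jv = eval_linear(LIN_OPS, dict(zip(LIN_INS, v)))[LIN_OUT]
jtu = transpose(LIN_OPS, LIN_INS, LIN_OUT, u)
print(abs(u * jv - sum(g * vi for g, vi in zip(jtu, v))) < 1e-12)  # True
```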

One point I would reiterate is that one of the main issues here is the lack of resources for a team dedicated to working on Julia’s AD ecosystem.