Open discussion on the state of differentiable physics in Julia


I knew you would ask, because you’re hitting THE pain point I know about right now. The supplement of the Universal Differential Equations paper describes the 4 vjp choices (a sketch of how these are selected in practice is shown right after the list):

  • Zygote.jl source-to-source transformation based vjps. Note that only non-mutating differential equation function definitions are supported in this mode. This mode is the most efficient in the presence of neural networks.
  • Enzyme.jl source-to-source transformation based vjps. This is the fastest vjp choice in the presence of heavy scalar operations like in chemical reaction networks, but it is currently not compatible with garbage collection and thus requires non-allocating f functions. (Note: since that was last updated, Enzyme has gained support for a subset of garbage collection.)
  • ReverseDiff.jl tape-based vjps. This allows for JIT compilation of the tape for accelerated computation. It is a fast vjp choice in the presence of heavy scalar operations like in chemical reaction networks, and it is more general in application than Enzyme, but it is not compatible with GPU acceleration.
  • Tracker.jl with arrays of tracked real values is used for mutating functions.

and AbstractDifferentiation.jl provides a similar perspective.
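To make that concrete, here’s a minimal sketch (my own illustration, not from the supplement) of how those four choices map onto the autojacvec option of the adjoint sensitivity algorithms in DiffEqSensitivity.jl (since renamed SciMLSensitivity.jl); the VJP type names are real, everything else is placeholder:

using DiffEqSensitivity  # now SciMLSensitivity.jl

# Each option selects which AD system computes the vjps inside the adjoint:
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())             # non-mutating f, neural networks
# sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())           # scalarized, non-allocating f
# sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))  # compiled tape, no branches, no GPU
# sensealg = InterpolatingAdjoint(autojacvec = TrackerVJP())          # mutating f with tracked arrays

# which then gets passed to the solve call that is being differentiated, e.g.
# solve(prob, Tsit5(), sensealg = sensealg)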

Then you also have the ForwardDiff.jl AD system, which is a scalar forward mode that you can think of as something like a compiled ReverseDiff.jl. Now here’s the way to understand this in practice. If you’re using DifferentialEquations.jl time stepping, the job for the ODE solver is a lot easier since it only has to differentiate your ODE function. With that in mind, here’s the process of choosing the right vjp:

  • If you tend to have a bunch of linear algebra, Zygote.jl works really well. But Zygote.jl cannot handle mutation (a small sketch of this limitation appears after this list). This means linear algebra code and GPU code go here, but other codes need to find a different solution.
  • Enzyme handles mutation, but it doesn’t have general support for higher-level Julia functionality (i.e. it only has partial support for generic and untyped code; it works really well on type-stable and “almost static” code), and many times a ChainRules.jl rule over some function might be better than just naively differentiating through the algorithm. It also has a lot of edge cases for non-static and allocating code. This means a lot of SciML codes go here, but they can hit an edge case of unsupported behavior, which 99% of the time seems to be BLAS (i.e. linear algebra).
  • You can sometimes work around Enzyme issues by going to ReverseDiff, but this is very rare and it’s only fast if the code is non-branching (i.e. has no if statements).
  • If you need to differentiate through a mostly non-mutating code that has a few mutations (like a few pop!s and such) that break Zygote, and it’s on the GPU, then Tracker can be the right solution.
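Here’s a tiny self-contained sketch (my own, not from the thread) of that Zygote mutation limitation:

using Zygote

f_ok(x) = sum(tanh.(x))      # pure, non-mutating: Zygote handles this fine

function f_mut(x)
    y = similar(x)
    y .= tanh.(x)            # in-place mutation of an array
    return sum(y)
end

Zygote.gradient(f_ok, rand(3))    # works
# Zygote.gradient(f_mut, rand(3)) # errors with "Mutating arrays is not supported"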

So that’s a big overview, but I think there’s a better way to understand it. Other than the hacks, there are two worlds. There is this world:

function dudt(u, p, t)
    # @unpack comes from UnPack.jl (also re-exported by Parameters.jl); p is a
    # ComponentArray with layers L1 and L2, each holding a weight matrix W and a bias b
    @unpack L1, L2 = p
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end

Yum, all linear algebra, Zygote eats that up. Works so well. (Source: Neural ODEs with DiffEqFlux · ComponentArrays.jl)
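For completeness, here’s a minimal way to drive that function (my own sketch, with made-up 2 → 16 → 2 layer sizes; the actual tutorial builds p from trained neural network layers):

using ComponentArrays, UnPack, Zygote

p = ComponentArray(L1 = (W = randn(16, 2), b = zeros(16)),
                   L2 = (W = randn(2, 16), b = zeros(2)))
u = randn(2)

dudt(u, p, 0.0)                                  # forward evaluation
Zygote.gradient(p -> sum(dudt(u, p, 0.0)), p)    # all linear algebra, so Zygote is happy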

And there is this world:

const N = 32
const xyd_brusselator = range(0,stop=1,length=N)
brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.
limit(a, N) = a == N+1 ? 1 : a == 0 ? N : a
function brusselator_2d_loop(du, u, p, t)
  A, B, alpha, dx = p
  alpha = alpha/dx^2
  @inbounds for I in CartesianIndices((N, N))
    i, j = Tuple(I)
    x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
    ip1, im1, jp1, jm1 = limit(i+1, N), limit(i-1, N), limit(j+1, N), limit(j-1, N)
    du[i,j,1] = alpha*(u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4u[i,j,1]) +
                B + u[i,j,1]^2*u[i,j,2] - (A + 1)*u[i,j,1] + brusselator_f(x, y, t)
    du[i,j,2] = alpha*(u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4u[i,j,2]) +
                A*u[i,j,1] - u[i,j,1]^2*u[i,j,2]
  end
end
p = (3.4, 1., 10., step(xyd_brusselator))

Look at that mutating, fully non-allocating beast of an ODE function. Enzyme smashes this (source on benchmarks: [1812.01892] A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions).
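And here’s roughly what using it from the sensitivity side looks like (a sketch only: the initial condition, solver choice, and loss are placeholders, though EnzymeVJP and InterpolatingAdjoint are the real option names):

using OrdinaryDiffEq, DiffEqSensitivity, Zygote  # DiffEqSensitivity is now SciMLSensitivity.jl

u0 = rand(N, N, 2)                       # placeholder initial condition
prob = ODEProblem(brusselator_2d_loop, u0, (0.0, 11.5), collect(p))

function loss(p)
    sol = solve(prob, FBDF(), p = p, saveat = 1.0,
                sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
    return sum(abs2, Array(sol))
end

Zygote.gradient(loss, collect(p))        # Enzyme computes the vjps of the mutating f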

The problem is the middle. Stick a little bit of mutation into the Zygote code, or stick a linear algebra operation into the second code, and both AD systems will fail. That’s the best way to describe why your example (Optimizing performance of 2D nonlinear diffusion UDE - #29 by JordiBolibar) is so difficult, and you’re not the only one to run into that. I manually worked around this for the UDE paper in the PDE examples by fully vectorizing and allocating a bit more than it should so that Zygote would work (a rough sketch of that style of workaround is below). That is a major pain and development hurdle.
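For reference, the flavor of that workaround is to replace the in-place stencil loop with whole-array, allocating operations that Zygote accepts. This is only a rough sketch of the idea, not the actual UDE-paper code:

# periodic neighbor indices, replacing the limit() wrap-around
const nxt = [mod1(i + 1, N) for i in 1:N]
const prv = [mod1(i - 1, N) for i in 1:N]
lap(X) = X[nxt, :] .+ X[prv, :] .+ X[:, nxt] .+ X[:, prv] .- 4 .* X

function brusselator_2d_vec(u, p, t)
  A, B, alpha, dx = p
  alpha = alpha/dx^2
  U, V = u[:, :, 1], u[:, :, 2]
  f = [brusselator_f(xyd_brusselator[i], xyd_brusselator[j], t) for i in 1:N, j in 1:N]
  dU = alpha .* lap(U) .+ B .+ U.^2 .* V .- (A + 1) .* U .+ f
  dV = alpha .* lap(V) .+ A .* U .- U.^2 .* V
  return cat(dU, dV; dims = 3)
end

It allocates several temporaries per call, which is exactly the “allocating a bit more than it should” cost, but there is no mutation left for Zygote to choke on.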

So then, what can we do about this? That’s precisely what the DJ4Earth project is all about. How do we avoid forcing people to change their code in order for differentiation to work? It’s a two-pronged approach:

  • On the Enzyme side, it needs to add support for as much higher-level functionality as it can. Now, it applies its transformations after Julia has lowered to LLVM, so it won’t have all of the information that Zygote/Diffractor has, but it can do a lot better. For example, supporting BLAS calls would probably fix 99% of cases. And so we have people working on that.
  • On the Zygote side, it needs to at least handle mutation. Now, it doesn’t have all of the context of Enzyme, so it won’t be able to as easily generate efficient mutation code, but it can support it better and use the growing set of Julia code analysis tools to remove generated allocations. Adding slow mutation support + improving the speed by using the compiler plugin interface would probably fix 99% of cases. And so we have people working on that, where the codegen speedup is the change to Diffractor.jl and there are known ways to support mutation (which would require compiler plugins to not be terribly slow).

Thus in the end Enzyme will grow to support more array codes and Zygote will change to Diffractor and grow to support higher order AD and more mutation over time.

What does this mean for a SciML user? The code that you generally see inside of ODE rhs definitions is usually much more constrained than general Julia code, and it falls almost perfectly into the domain of Enzyme. I think for most SciML applications like UDEs, an improved Enzyme will make people happy and will generally be the default vjp choice. However, for differentiation on the outside, like the differentiation call that is calling the ODE solver, it will likely change from Zygote to Diffractor, and higher-order AD cases will work much more nicely from that. Enough mutation support will be available (not now, but in the further future of Diffractor) that you could differentiate the solver directly, but the adjoints + Enzyme will likely be the more efficient option. The SciML codes already know how to automatically mix the AD systems and run compiler analyses of the ODE to choose the AD mix, so there won’t be any user issues around this; the only reason it currently fails is that it can hit these scenarios where no AD is effective.

That should explain why so much is easy yet some codes feel so hard to differentiate right now, and at the same time why an effective solution is in sight.
