Open discussion on the state of differentiable physics in Julia


I have been thinking about this for some time, as a result of my personal experience starting to work with Julia this year. I chose Julia for my new research project on Universal Differential Equations for glacier evolution modelling, partly after seeing so many cool projects (e.g. CliMA or DJ4Earth) related to differentiable programming for earth sciences.

My experience with Julia and its community has been overwhelmingly positive. I quickly got up to speed with the language, the community is very welcoming, and it’s full of cutting-edge projects. Nonetheless, I noticed that while some SciML packages are extremely mature and stable (e.g. DifferentialEquations.jl), others, particularly those related to automatic differentiation, are newer and more experimental. It is precisely the AD part of the SciML ecosystem that has presented most of the difficulties in my work over the last months. I’m perfectly aware that this is likely due to the huge technical challenge involved in developing these packages, so huge respect and kudos to all the developers.

With this post, I would just like to get a general sense of the current state of differentiable programming applied to physics (and perhaps particularly to earth sciences) in Julia. What is currently being done out there? What are the main challenges and difficulties encountered by these different projects? What has worked and what has not worked for you? What is possible to do with differentiable physics in Julia right now? What needs more time to work smoothly?

15 Likes

I knew you would ask, because you’re hitting THE pain point I know about right now. The Universal Differential Equations paper supplement describes the 4 vjp choices:

  • Zygote.jl source-to-source transformation based vjps. Note that only non-mutating differential equation function definitions are supported in this mode. This mode is the most efficient in the presence of neural networks.
  • Enzyme.jl source-to-source transformation based vjps. This is the fastest vjp choice in the presence of heavy scalar operations like in chemical reaction networks, but is currently not compatible with garbage collection and thus requires non-allocating f functions. (Note: since that was last updated, Enzyme got support for a subset of garbage collection)
  • ReverseDiff.jl tape-based vjps. This allows for JIT-compilation of the tape for accelerated computation. This is a fast vjp choice in the presence of heavy scalar operations like in chemical reaction networks, and it is more general in application than Enzyme. It is not compatible with GPU acceleration.
  • Tracker.jl vjps, with arrays of tracked real values, are utilized on mutating functions

and AbstractDifferentiation.jl provides a similar perspective.

Then you also have the ForwardDiff.jl AD system, which is a scalar forward mode that you can think of as a compiled ReverseDiff.jl. Now here’s the way to understand this in practice. If you’re using DifferentialEquations.jl time stepping, the job for the ODE solver is a lot easier since it only has to differentiate your ODE function. That said, here’s the process of choosing the right vjp (a code sketch of where this choice actually shows up follows the list):

  • If you tend to have a bunch of linear algebra, Zygote.jl works really well. But Zygote.jl cannot handle mutation. This means linear algebra code goes here, and GPU codes go here, but other codes need to find a different solution.
  • Enzyme handles mutation, but it doesn’t have general support for higher level Julia functionality (i.e. it only has partial support for generic and untyped code; it works really well on type-stable and “almost static” code), and many times ChainRules.jl over some function might be better than just naively differentiating through the algorithm. It also has a lot of edge cases for non-static and allocating code. This means a lot of SciML codes go here, but it can hit an edge case of unsupported behavior, which 99% of the time seems to be BLAS (i.e. linear algebra).
  • You can sometimes work around Enzyme issues by going to ReverseDiff, but this is very rare and it’s only fast if the code is non-branching (i.e. has no if statements).
  • If you need to differentiate through a mostly non-mutating code that has a few mutations (like a few pop!s and such) that break Zygote, and it’s on the GPU, then Tracker can be the right solution.
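
To make that concrete, here is a minimal sketch (assuming OrdinaryDiffEq, Zygote, and DiffEqSensitivity — now SciMLSensitivity — are loaded; the Lotka-Volterra rhs is just a placeholder) of where the vjp choice actually shows up: the autojacvec option of the adjoint sensealg picks which AD backend differentiates the ODE function.

using OrdinaryDiffEq, DiffEqSensitivity, Zygote

# Mutating rhs, so ReverseDiffVJP or EnzymeVJP are the natural vjp choices;
# ZygoteVJP() would require a non-mutating rhs, TrackerVJP() covers GPU + mutation.
function lotka!(du, u, p, t)
    du[1] =  p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
end

u0 = [1.0, 1.0]
p  = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka!, u0, (0.0, 10.0), p)

loss(p) = sum(abs2, Array(solve(prob, Tsit5(); p = p, saveat = 0.1,
              sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))))

Zygote.gradient(loss, p)   # reverse-mode adjoint using the chosen vjp internally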

So that’s a big overview, but I think there’s a better way to understand it. Other than the hacks, there are two worlds. There is this world:

# p is a ComponentArray of layer parameters (fields L1, L2); @unpack is from UnPack.jl
function dudt(u, p, t)
    @unpack L1, L2 = p
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end

Yum, all linear algebra, Zygote eats that up. Works so well. (Source: Neural ODEs with DiffEqFlux · ComponentArrays.jl)
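
As a standalone illustration (with plain weight matrices standing in for the ComponentArrays parameters of the linked example), Zygote differentiates this kind of purely non-mutating linear-algebra code without any fuss:

using Zygote

W1, b1 = randn(16, 2), randn(16)
W2, b2 = randn(2, 16), randn(2)

# Same shape as the rhs above: matmuls and broadcasts, no mutation anywhere
nn_rhs(u) = W2 * tanh.(W1 * u.^3 .+ b1) .+ b2

u = randn(2)
Zygote.gradient(u -> sum(abs2, nn_rhs(u)), u)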

And there is this world:

const N = 32
const xyd_brusselator = range(0,stop=1,length=N)
brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.
limit(a, N) = a == N+1 ? 1 : a == 0 ? N : a
function brusselator_2d_loop(du, u, p, t)
  A, B, alpha, dx = p
  alpha = alpha/dx^2
  @inbounds for I in CartesianIndices((N, N))
    i, j = Tuple(I)
    x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
    ip1, im1, jp1, jm1 = limit(i+1, N), limit(i-1, N), limit(j+1, N), limit(j-1, N)
    du[i,j,1] = alpha*(u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4u[i,j,1]) +
                B + u[i,j,1]^2*u[i,j,2] - (A + 1)*u[i,j,1] + brusselator_f(x, y, t)
    du[i,j,2] = alpha*(u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4u[i,j,2]) +
                A*u[i,j,1] - u[i,j,1]^2*u[i,j,2]
  end
end
p = (3.4, 1., 10., step(xyd_brusselator))

Look at that mutating, fully non-allocating beast of an ODE function. Enzyme smashes this (source on benchmarks: [1812.01892] A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions).
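
As a rough sketch of what calling Enzyme directly on this rhs could look like (assuming a recent Enzyme.jl API; the adjoint seed λ here is just a placeholder), the shadow of du carries the incoming adjoint and the pullback accumulates into the shadow of u:

using Enzyme

u  = rand(N, N, 2)
du = zeros(N, N, 2)
λ  = ones(N, N, 2)     # placeholder seed for d(loss)/d(du)
ū  = zeros(N, N, 2)    # gradient w.r.t. u accumulates here

Enzyme.autodiff(Enzyme.Reverse, brusselator_2d_loop, Enzyme.Const,
                Duplicated(du, λ), Duplicated(u, ū),
                Enzyme.Const(p), Enzyme.Const(0.0))

In practice you would not call this yourself when using the DifferentialEquations.jl adjoints; the EnzymeVJP choice does roughly the equivalent seeding internally.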

The problem is the middle. Stick a little bit of mutation into the Zygote code, or stick a linear algebra operation into the second code, and both AD systems will fail. That’s the best way to describe why your example (Optimizing performance of 2D nonlinear diffusion UDE - #29 by JordiBolibar) is so difficult, and you’re not the only one to run into that. I manually worked around this for the UDE paper in the PDE examples by fully vectorizing and allocating a bit more than it should have so that Zygote would work. That is a major pain and a development hurdle.

So then, what can we do about this? That’s precisely what the DJ4Earth project is all about. How do we not force people to change code in order to allow differentiation to work? It’s a two pronged approach:

  • On the Enzyme side, it needs to add support for as much higher level functionality as it can. Now, it applies its transformations after Julia has lowered to LLVM, so it won’t have all of the information that Zygote/Diffractor has, but it can do a lot better. For example, supporting BLAS calls would probably fix 99% of cases. And so we have people working on that.
  • On the Zygote side, it needs to at least handle mutation. Now, it doesn’t have all of the context of Enzyme, so it won’t be able to generate efficient mutation code as easily, but it can support mutation better and use the growing set of Julia code analysis tools to remove generated allocations. Adding slow mutation support + improving the speed by using the compiler plugin interface would probably fix 99% of cases. And so we have people working on that, where the codegen speedup is the change to Diffractor.jl and there are known ways to support mutation (which would require compiler plugins to not be terribly slow).

Thus in the end Enzyme will grow to support more array codes, and Zygote will change to Diffractor, which will grow to support higher order AD and more mutation over time.

What does this mean for a SciML user? The code that you generally see inside of ODE rhs definitions is usually much more constrained than general Julia code, and it falls almost perfectly into the domain of Enzyme. I think for most SciML applications like UDEs, an improved Enzyme will make people happy and will generally be the default vjp choice. However, for differentiation on the outside, like the differentiation call that is calling the ODE solver, it will likely change from Zygote to Diffractor, and higher order AD cases will work much more nicely after that. Enough mutation support will be available (not now, but in the further future of Diffractor) that you could differentiate the solver, but the adjoints + Enzyme will likely be the more efficient option. The SciML codes already know how to automatically mix the AD backends and run compiler analyses of the ODE to choose the AD mix, so there won’t be any user issues around this; the only reason it currently fails is that it can hit these scenarios where no AD is effective.

That should explain how so much is easy but why some codes feel so hard to differentiate right now, but at the same time how an effective solution is in sight.

46 Likes

Where does Diffractor fit in?

It’s not too much of a simplification to just call Diffractor a Zygote that compiles faster and can optimize code better.

3 Likes

It definitely isn’t there yet, but it might theoretically be able to handle scalars well in the future?

Enzyme is good at the performance of mutation and such because it has access to all of the analysis passes of LLVM. In order for Diffractor to get similar performance to Enzyme, it would need to have escape analysis and all of that in order to delete the allocations. One suggestion for Diffractor to support mutation well would be to complete https://github.com/JuliaLang/julia/pull/42465, then turn mutation operations into operations on immutable values that allocate copies, and then let the suggested copy-deletion pass of that PR remove the unnecessary copies. That said, that pass to remove the unnecessary copies is really far away, and it needs Shuhei’s EscapeAnalysis.jl (GitHub: aviatesk/EscapeAnalysis.jl, which analyzes escape information in Julia IR).

1 Like

Thanks a lot Chris for such a complete answer. This summarizes and clarifies many elements of the discussions we’ve been having recently, and offers a clear overview of it all.

I must say I’m also kind of relieved to know that the problem I’m working on is one of the hardest one can tackle with AD right now. This explains why it has been so hard to make progress these last months.

Now, I must ask: what is the best strategy for these intermediate hard cases mixing linear algebra and mutation right now? From your analysis I see two potential ways out (but correct me if I’m wrong):

  1. Using Zygote: I know that in practice it can work for my problem, since I already have a manual implementation of it. The issue there is that in order to avoid mutation one needs a ton of buffers and allocations, making any long forward run extremely memory-costly to differentiate. A potential way to compensate for that would be to use a very efficient solver in order to minimize time stepping. Would that be enough to apply Zygote to such cases, or would one still be limited to very short simulations (i.e. a limited number of ODEs)? Otherwise, would there be any other way to optimize the code in order to use Zygote for such a problem?

  2. Using Enzyme: I’ve encountered exactly the issue you mentioned, a BLAS call which is not supported. Enzyme seems like a perfect solution for avoiding any memory allocation, but given the frequent use of linear algebra in this sort of physical problem, it also seems pretty daunting. Would it be feasible, for my problem, to use ChainRules for some functions in order to make it work, or would that not help at all with the lack of support for linear algebra?

I need to find some compromise between these two methods, even if it’s not ideal performance-wise. My goal is to have something functional right away so I can start exploring the application of UDEs in my field and then, as soon as one of the two packages fixes one of these main issues, switch to that option in order to optimize the code. My impression is that some models in the DJ4Earth project (the ice flow model for sure) will have very similar requirements.

1 Like

1 is probably easiest. Take the performance hit of non-mutation and go with it. As your problem size increases, if you’re using an implicit method, progressively more time will be spent in the factorization and matrix multiplications (which grow as O(n^3)), and thus the allocations won’t matter after a while. The hard place of course is the “midsized” problems for which the asymptotic behavior is not a good enough reason to leave off performance tricks. Still, it’s what I’d recommend today.
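
If it helps, a minimal sketch of that non-mutating style on a toy 1D diffusion stencil (not your actual model): Zygote.Buffer gives a write-once scratch array whose copy is differentiable, so you pay extra allocations per step instead of fighting the mutation restriction.

using Zygote

function diffusion_step(u, D, dx)
    n  = length(u)
    du = Zygote.Buffer(u)          # differentiable scratch array
    du[1] = 0.0
    du[n] = 0.0
    for i in 2:n-1
        du[i] = D * (u[i-1] - 2u[i] + u[i+1]) / dx^2
    end
    return copy(du)                # copy back to a normal array so AD can proceed
end

u = rand(64)
Zygote.gradient(u -> sum(diffusion_step(u, 0.1, 0.01)), u)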

For 2, you cannot just use ChainRules with Enzyme easily. That’s kind of the whole problem there: when the code gets down to the LLVM level where Enzyme acts, there’s no guarantee that your function calls even exist anymore. Those calls may have been inlined by Julia, and so there would be no way to intercept them in that case. Fancy tricks can probably be used to handle a lot of cases (and it has to get fancy even just to support allocations in the first place, or any dynamism), but I wouldn’t expect an average user to hack on that at all. Instead, BLAS support should be coming soon (it’s actively being worked on by some of the Julia Lab students), and that should cover most of what’s needed in most cases.

5 Likes

For anyone stumbling on this thread and still eager to read more about this, @ChrisRackauckas wrote an excellent post on the current (and future) landscape of AD tools in Julia, Python and other languages, which nicely complements this discussion.

13 Likes

Thanks to all of you for sharing your experiences!

I recently learned about the flexibility of Julia’s AD libraries and now I’m very intrigued. I intend to use Enzyme to implement some algorithms for my PhD project, and I’m not sure if the tools are ready for my requirements yet, so now I’m asking the experts. :slight_smile:

Essentially I will need to build a sparse matrix, apply it to a vector, and then calculate a scalar loss function from the resulting vector. Building the sparse matrix is expensive, but the sub-algorithms are parallelisable using GPU kernels. I know that Enzyme can perform AD for GPU kernels, but now I’m struggling with chaining multiple kernel calls and combining them with Zygote operations to calculate the loss function.
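
For concreteness, here is the rough pattern I have in mind, written as a CPU sketch with a hypothetical stand-in kernel apply_A! instead of my real GPU kernels: wrap the Enzyme-differentiated kernel in a ChainRulesCore.rrule so that Zygote can handle the surrounding loss code.

using ChainRulesCore, Enzyme, Zygote

function apply_A!(y, x)                  # stand-in for "assemble the operator and apply it"
    for i in eachindex(y)
        y[i] = 2x[i] + (i > 1 ? x[i-1] : 0.0)
    end
    return nothing
end

apply_A(x) = (y = similar(x); apply_A!(y, x); y)

function ChainRulesCore.rrule(::typeof(apply_A), x)
    y = apply_A(x)
    function apply_A_pullback(ȳ)
        x̄ = zero(x)
        # Enzyme reverse pass over the in-place kernel: the shadow of y is seeded
        # with the incoming adjoint (copied, since Enzyme zeroes it), and the
        # gradient w.r.t. x accumulates into x̄
        Enzyme.autodiff(Enzyme.Reverse, apply_A!, Enzyme.Const,
                        Duplicated(copy(y), collect(ȳ)), Duplicated(x, x̄))
        return NoTangent(), x̄
    end
    return y, apply_A_pullback
end

x = rand(16)
Zygote.gradient(x -> sum(abs2, apply_A(x)), x)   # Zygote outside, Enzyme inside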

Do you think such a project can be implemented with the tools at hand in an efficient way?

1 Like