Plugging DifferentiationInterface.jl into PyTorch

I should preface this by saying that I am not very experienced with either PyTorch or juliatorch, so I apologize for any misunderstandings on my part.

We are developing a physics simulation ecosystem that can be differentiated using almost any of the AD backends in DifferentiationInterface.jl (our interface is mutating with scalar indexing on the CPU, so e.g. Zygote.jl doesn't really work). Many of our collaborators/prospective users, however, are sticking with Python no matter what, given their current workflow setups, and will not use or learn Julia. No problem; we are making sure our Julia ecosystem is easily usable from within Python, which the juliacall package has made much easier.

One of the most important things for our software is easy integration with PyTorch. It seems to me that the best way to handle this is to follow the steps here to plug our Julia functions into PyTorch's autograd framework: Extending PyTorch — PyTorch 2.8 documentation
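
For context, my understanding is that this boils down to subclassing torch.autograd.Function and supplying forward and backward ourselves. A minimal toy example of the mechanism (nothing Julia-specific yet):

```python
import torch

class SquareSum(torch.autograd.Function):
    # Toy stand-in for an "external" computation: f(x) = sum(x**2),
    # with the gradient 2*x supplied by hand in backward().

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)        # stash whatever backward will need
        return (x * x).sum()            # runs with autograd tracing disabled

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x      # vector-Jacobian product w.r.t. x

x = torch.randn(5, requires_grad=True)
SquareSum.apply(x).backward()           # custom Functions are invoked via .apply
print(torch.allclose(x.grad, 2 * x))    # True: autograd used our backward()
```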

To my understanding, the juliatorch package does exactly this in a generic way, so that Julia functions can be differentiated through PyTorch's autograd. This may be what we end up using in the long term.

However, I imagine there could be significant performance gains from computing the gradients directly in Julia and then passing those gradients to PyTorch. What I would really like is to compute the gradients using DifferentiationInterface.jl: in Julia I can then specify AutoReverseDiff(;compile=true), for example, and get very fast reverse-mode gradients to hand back to PyTorch. I worry that differentiating through PyTorch's tensors directly may not achieve similar performance, especially given our mutating interface.
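
To sketch the kind of workflow I have in mind (a toy function stands in for our simulation; this assumes a recent DifferentiationInterface with the preparation API, and that DifferentiationInterface and ReverseDiff are installed in the Julia project that juliacall uses):

```python
import numpy as np
from juliacall import Main as jl

jl.seval("using DifferentiationInterface, ReverseDiff")
jl.seval("f(x) = sum(abs2, x)")   # toy stand-in for the simulation

# Prepare once: with AutoReverseDiff(compile=true) this records and compiles
# the tape, which is then reused for every subsequent gradient call.
jl.seval("const backend = AutoReverseDiff(compile=true)")
jl.seval("const prep = prepare_gradient(f, backend, zeros(5))")
# Copy the input into a plain Vector{Float64} so the compiled tape sees the
# same input type it was prepared for.
jl.seval("grad_f(x) = gradient(f, prep, backend, collect(Float64, x))")

x = np.random.rand(5)
g = np.asarray(jl.grad_f(x))   # gradient computed in Julia, returned to Python
print(np.allclose(g, 2 * x))   # True
```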

If there are indeed major performance gains to be had here, then this tooling could find broad use in the wider community.

From what I see in its source code, juliatorch does indeed plug into PyTorch's extension mechanism by computing gradients in Julia (with ForwardDiff) and passing them back to Python. It wouldn't be very difficult to replace ForwardDiff with another backend like Enzyme or Mooncake. If you want to code this generically with DI, the only thing missing would be a splittable pullback API, which would not be much work for me to implement, but no one has asked for it before.
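
To make that concrete, here is a rough sketch of what such a wrapper could look like with the backward pass delegated to DI (hypothetical names; CPU, Float64 and scalar outputs only; no preparation reuse):

```python
import numpy as np
import torch
from juliacall import Main as jl

# Assumes DifferentiationInterface and ReverseDiff are installed in the active
# Julia project; any other DI backend could be substituted for `backend`.
jl.seval("using DifferentiationInterface, ReverseDiff")
jl.seval("const backend = AutoReverseDiff(compile=true)")

# Small Julia helpers: call f on a plain Vector{Float64}, and compute its
# gradient with the chosen DI backend.
_call_f = jl.seval("(f, x) -> f(collect(Float64, x))")
_grad_f = jl.seval("(f, x) -> gradient(f, backend, collect(Float64, x))")

class JuliaScalarFunction(torch.autograd.Function):
    """Differentiate a scalar-valued Julia function f(x::Vector) through
    PyTorch autograd. CPU / Float64 / scalar output only."""

    @staticmethod
    def forward(ctx, x, f):
        ctx.save_for_backward(x)
        ctx.f = f
        y = _call_f(f, x.detach().double().numpy())
        return torch.as_tensor(float(y), dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Reverse-mode gradient computed entirely on the Julia side with DI.
        g = _grad_f(ctx.f, x.detach().double().numpy())
        g = torch.as_tensor(np.array(g), dtype=x.dtype)
        return grad_output * g, None   # no gradient w.r.t. the function argument

# Usage: PyTorch drives the outer graph, Julia computes the inner gradient.
f = jl.seval("x -> sum(abs2, x)")
x = torch.randn(5, requires_grad=True)
JuliaScalarFunction.apply(x, f).backward()
print(torch.allclose(x.grad, 2 * x))   # True
```

Note that this sketch redoes the differentiation work in backward; the splittable pullback API mentioned above would let the forward call return both the value and a reusable pullback closure instead.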

However, converting arrays back and forth between languages can be costly if they can’t reuse the same memory. I’m not sure what the current best approach for that is, and you might need to read up on things like DLPack.jl. Another important question is whether you need your code to run on GPU or only CPU.
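
As a CPU-only illustration of what reusing memory can look like (assuming juliacall's default conversion wraps the NumPy buffer as a Julia array view rather than copying it; GPU sharing is a different story and would involve DLPack):

```python
import torch
from juliacall import Main as jl

# On CPU, Tensor.numpy() is a view of the tensor's memory, and juliacall is
# assumed to wrap that NumPy buffer as a Julia AbstractArray without copying,
# so an in-place Julia function mutates the torch tensor directly.
jl.seval("filltwos(x) = (x .= 2; nothing)")   # toy stand-in for a mutating kernel

x = torch.zeros(5)
jl.filltwos(x.numpy())      # mutate the shared buffer from Julia
print(x)                    # tensor([2., 2., 2., 2., 2.])
```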

@avikpal I know that Reactant.jl plays nicely with JAX, how about PyTorch?


Thanks for your very detailed reply! The link you posted to the source code is great; it doesn't seem very challenging to rewrite this more generically using DI. Something to work out would then be how to select a backend elegantly from within Python.
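
One possibility I can imagine (purely a sketch, assuming DI re-exports the ADTypes backend constructors) would be to map a Python string plus keyword arguments onto those constructors:

```python
from juliacall import Main as jl

# Assumes the ADTypes constructors (AutoReverseDiff, AutoEnzyme, ...) are in
# scope, e.g. via DifferentiationInterface's re-exports.
jl.seval("using DifferentiationInterface, ReverseDiff")

def make_backend(name, **kwargs):
    ctor = jl.seval(name)      # e.g. the AutoReverseDiff constructor
    return ctor(**kwargs)      # Python keyword args forwarded to Julia

backend = make_backend("AutoReverseDiff", compile=True)
print(backend)                 # AutoReverseDiff(compile=true) (or similar)
```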

Indeed, we would want it to work on the GPU too. Currently it works well with regular numbers thanks to KernelAbstractions.jl, though I haven't actually tested differentiation on the GPU yet. I will say that with Enzyme.jl the compiler throws a weird error that I haven't had time to work into an MWE. It may be related to one level of type instability that we have.

If you have a reproducer, even if it's not minimal, I can take a look?


Sure! I'll submit an issue with it soon… there are a couple of development things we need to clean up and release first.


There's GitHub - llvm/torch-mlir: The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem (and GitHub - pytorch/xla: Enabling PyTorch on XLA Devices (e.g. Google TPU)), so we could in principle do the same trick as JAX: compile to FX, which is then lowered to Torch-MLIR, and then feed that into Reactant via hlo_call.