I have a pretty basic question. I am a scientist who is new to Julia (hence the category I am posting this in), but interested in learning more both out of personal curiosity as well as the possible benefits for my own work.
There is one step in a longer analysis pipeline in my field that would benefit from autodiff without sacrificing speed. This step solves a system of differential equations to compute the theoretical predictions of a model, which are then compared against data in, e.g., MCMC. Other elements of this pipeline have autodiff implementations written by others in the community, predominantly in Python with JAX.
I am interested in plugging this hole in the pipeline by writing a Julia package that computes these theoretical predictions with autodiff. I prefer Julia over Python with JAX for a few reasons.
Performance. A single computation of the theoretical quantities mentioned above can take O(minutes) in the modern C and/or Fortran codes widely used. Given the number of such computations needed for an MCMC (for example), it is important that the code is fast. My suspicion is that it would be much harder to write Python code that is as performant as Julia.
Ease of autodiff. It is probably possible to write autodiff code that accomplishes this task in Python with JAX, but my impression is that it would be tedious to do while balancing clarity and performance. I suspect this is why no one in the community has really done it yet.
Clarity. The module I plan to write, in the best case scenario, would be easily extensible by others looking to use the code, say by adding terms to the system of equations that gets solved. I think, with well structured code, this should be easy even for new Julia users.
Given these motivations, I think Julia is a good option for me. But since most of the pipeline is done with JAX, I am wondering whether my idea is even possible to begin with. If I have a Julia module with autodiff implemented, can I benefit from that autodiff in a Python pipeline that uses JAX, say by using the pyjulia package? Are there other considerations here that I am not aware of?
I think your best bet combines Reactant.jl with Enzyme.jl. I’m not sure exactly how to combine both to do what you want, but @wsmoses will know.
You may also be interested in past discussion on the topic:
Hello, just to make sure: do you want JAX to autodiff the computation done in Julia? In that case I don’t think it’s possible without generating the MLIR in Julia (with Reactant.jl, for instance). However, if you only want the forward pass to be in Julia, I’m pretty sure that’s fine, but you will want to reduce as much as possible the Julia array → Python array → JAX array communication, which can get expensive if done too often.
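To illustrate the point about reducing boundary crossings, here is a minimal sketch. The `julia_theory` function below is a hypothetical stand-in (plain NumPy, so the example is self-contained) for a Julia function reached through a Python bridge such as pyjulia or juliacall; each call across the bridge would pay the array-conversion cost, so one batched call beats many per-sample calls.

```python
import numpy as np

# Hypothetical stand-in for a Julia function called through a Python bridge.
# In the real pipeline, every call would pay a Julia <-> Python conversion cost.
def julia_theory(params_batch):
    return np.asarray(params_batch) ** 2

samples = np.ones((100, 3))

# Costly pattern: one boundary crossing per sample.
per_sample = np.stack([julia_theory(p) for p in samples])

# Better pattern: one crossing for the whole batch of parameters.
batched = julia_theory(samples)
```

Both give the same result; the difference is only in how many times data would cross the Julia/Python boundary.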
I guess you should try it, and if something really bothers you, ask about it here.
Also, if you’re going through Julia, JAX won’t be able to fuse operations as much as it does with pure Python code, and it won’t compile that part well either, so you will lose the XPU advantage, which you may have a hard time rebuilding.
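The fusion point can be made concrete with `jax.pure_callback`, the standard way to embed non-traceable host code in a jitted JAX function. In this sketch `external_step` is a hypothetical stand-in (plain NumPy) for a call into Julia: the JAX ops before and after it still compile under `jax.jit`, but XLA cannot fuse across the callback, which is exactly the lost-fusion cost described above.

```python
import math
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for a Julia call: opaque to JAX's tracer/compiler.
def external_step(x):
    return np.sin(np.asarray(x))

@jax.jit
def pipeline(x):
    y = jnp.exp(x)                        # compiled and fusable by XLA
    z = jax.pure_callback(                # opaque region: no fusion across it
        external_step, jax.ShapeDtypeStruct(x.shape, x.dtype), y
    )
    return jnp.sum(z * 2.0)               # compiled again after the callback

out = pipeline(jnp.zeros(3))              # 3 * 2 * sin(exp(0)) = 6 * sin(1)
```

Note that `pure_callback` alone gives you a working forward pass inside `jit`, but not gradients; those require a custom rule (see below in the thread).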
Thanks for the pointer to Reactant.jl, definitely seems useful. I was already familiar with Enzyme.jl and currently plan to use it if I go forward with the project.
Thanks as well for the link to the other discussion. I had come across that thread in my searches prior to posting, and it’s definitely useful for me, but my problem is sort of the converse of that (I want to call Julia autodiff code from a predominantly JAX pipeline). Since I’m relatively new to the language and ecosystem it wasn’t clear to me that what I’m after was possible.
I don’t think that’s exactly what I’m after, but I may be misunderstanding you. For simplicity, you can picture the pipeline I’m talking about as having two steps,

\text{parameters} \xrightarrow{\ \mathcal{T}\ } \text{prediction} \xrightarrow{\ \mathcal{L}\ } \text{likelihood},

where the first step is the theoretical prediction \mathcal{T} I want to implement in Julia, and the second step takes that prediction and passes it to a likelihood function \mathcal{L} written with JAX. My question is: is there a way to do this such that the whole pipeline is autodifferentiable (say, by calling the Julia code that gives you \mathcal{T} from Python)? I would also be interested in a setup where, e.g., I call the \mathcal{L} code from Julia to do the same.
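One way to sketch this end-to-end differentiability in plain JAX is `jax.custom_vjp`: wrap the external \mathcal{T} in a `pure_callback` and attach a hand-supplied reverse rule, so `jax.grad` of the JAX-side likelihood flows through the foreign call. Here both `theory_prediction` and its VJP are hypothetical NumPy stand-ins for the Julia code (in the real setup, Enzyme.jl could supply the VJP); the shapes and the quadratic toy model are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the Julia theory code T(params).
def theory_prediction(params):
    return np.asarray(params) ** 2

# Hypothetical stand-in for the Julia-side VJP (e.g. from Enzyme.jl):
# cotangent * dT/dparams for T(p) = p^2.
def theory_vjp(params, cotangent):
    return 2.0 * np.asarray(params) * np.asarray(cotangent)

@jax.custom_vjp
def T(params):
    return jax.pure_callback(
        theory_prediction, jax.ShapeDtypeStruct(params.shape, params.dtype), params
    )

def T_fwd(params):
    return T(params), params          # save params as residuals for the bwd pass

def T_bwd(params, ct):
    grad = jax.pure_callback(
        theory_vjp, jax.ShapeDtypeStruct(params.shape, params.dtype), params, ct
    )
    return (grad,)

T.defvjp(T_fwd, T_bwd)

# JAX-side likelihood L built on the "Julia" prediction.
def neg_log_like(params, data):
    return jnp.sum((T(params) - data) ** 2)

params = jnp.array([1.0, 2.0])
data = jnp.array([0.0, 0.0])
g = jax.grad(neg_log_like)(params, data)   # analytically 4 * params**3 here
```

The design point is that JAX never needs to trace the foreign code; it only needs a shape/dtype contract for the forward call and a cotangent-in, gradient-out contract for the reverse call, which is also the idea behind Enzyme-JAX mentioned below.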
Yeah, Enzyme-JAX will let you do exactly that converse (call Julia from JAX; see the Lux docs on it above). The README for Enzyme-JAX has an example of calling C++ from JAX (we don’t yet have a native Julia-as-a-string form; the setup from Julia into JAX requires the export, but if folks are interested we can add that feature too). Note that even with jax.grad called on the outside, using Enzyme-JAX to import the code will use Enzyme(MLIR) to autodiff it while staying compatible with JAX’s existing primitives, since JAX doesn’t actually support autodiffing non-Python code.
Thanks all, great to know this is generally possible and I really appreciate the links to concrete ways to get started. Enzyme-JAX seems like it might be the way to go for me.