Using Julia autodiff code from Python with JAX

I think your best bet is to combine Reactant.jl with Enzyme.jl. I’m not sure exactly how to wire the two together for what you want, but @wsmoses will know.
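In the meantime, here is a rough sketch of what the JAX side could look like if you have a function and its gradient available from outside JAX. To be clear, this does not use Reactant.jl; it is the generic `jax.custom_vjp` plus `jax.pure_callback` pattern. `_host_f` and `_host_grad_f` are hypothetical NumPy stand-ins for a Julia function and its Enzyme.jl gradient, and the bridge that would actually call into Julia is assumed, not shown.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-ins: in a real setup these would call into Julia
# (the function itself and a gradient computed there, e.g. by Enzyme.jl).
def _host_f(x):
    x = np.asarray(x)
    return np.asarray(np.sum(np.sin(x) ** 2), dtype=x.dtype)

def _host_grad_f(x):
    x = np.asarray(x)
    return np.asarray(2.0 * np.sin(x) * np.cos(x), dtype=x.dtype)

@jax.custom_vjp
def f(x):
    # Run the external function on the host; JAX only sees shape and dtype.
    return jax.pure_callback(_host_f, jax.ShapeDtypeStruct((), x.dtype), x)

def f_fwd(x):
    # Keep the input as the residual needed by the backward pass.
    return f(x), x

def f_bwd(x, g):
    # Fetch the externally computed gradient and scale it by the cotangent g.
    grad = jax.pure_callback(
        _host_grad_f, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )
    return (g * grad,)

f.defvjp(f_fwd, f_bwd)

x = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32)
value, grad = jax.value_and_grad(f)(x)
```

With this in place `jax.grad` and `jax.jit` treat `f` like any other differentiable function, but every call goes back through the host, so it will likely be slower than whatever a proper Reactant.jl-based route gives you.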
You may also be interested in past discussion on the topic: