For the third time, I am trying to implement some wave optics algorithms in Julia, but each time I start, I suffer a bit from the current state of automatic differentiation.
In principle, my functions look like:
```julia
function propagate(field::AbstractArray{T, 6}, distance::NumberOrArray,
                   wavelengths::NumberOrArray, pixel_size) where {T}
    H = # some kernel calculated with distances and wavelengths
    field_p = # some processing on the field with wavelengths and distance
    field_prop = ifft(fft(field_p) .* H)  # roughly
    return field_prop
end
```
The adjoint is a bit nasty to write manually since distance and wavelengths can each be either a scalar or a vector. The output will be 6-dimensional (x, y, z, wavelength, polarization, batch), but some dimensions are singletons depending on the types of distance and wavelengths.
Anyway, it seems like Enzyme.jl does not support FFTs at the moment (the latest is this issue, or this older one). So in my current code I use a mix of Zygote and custom-written ChainRules.jl rules.
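For reference, here is a minimal sketch of the kind of custom rule I mean for the FFT core of `propagate` (the name `apply_kernel` and the cotangent variables are illustrative only, and any summation over broadcasted singleton dimensions of `H` is left out):

```julia
using ChainRulesCore, FFTW

# Toy stand-in for the FFT part of `propagate`: y = ifft(fft(field) .* H)
apply_kernel(field, H) = ifft(fft(field) .* H)

function ChainRulesCore.rrule(::typeof(apply_kernel), field, H)
    F = fft(field)
    y = ifft(F .* H)
    n = length(y)
    function apply_kernel_pullback(ȳ)
        u = fft(unthunk(ȳ)) ./ n   # adjoint of ifft is fft/n
        f̄ = bfft(conj.(H) .* u)    # adjoint of fft is bfft (no 1/n factor)
        H̄ = conj.(F) .* u          # kernel cotangent (broadcast sums omitted)
        return NoTangent(), f̄, H̄
    end
    return y, apply_kernel_pullback
end
```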
I’d love to use DifferentiationInterface.jl but if FFTs do not work, that’s gonna be a hard time.
Does anyone have an idea what I should do? I feel like using Zygote.jl is not quite future-proof, since everyone else seems to be switching to Enzyme.jl.
Honestly, I am seriously considering implementing it in another language, because this pain has been ongoing for years (I started hitting these general issues in 2020). I know there is progress on this front and things have improved, but they haven't reached my feature requirements yet (CUDA + FFT + complex numbers). And implementing an Enzyme rule seems just very hard for me. I would love to help out a bit, but right now I want to focus on my real research problems.
I see similar packages implemented in PyTorch and JAX, and it works well for them, so I am honestly a bit hopeless and cannot recommend Julia to them right now.
I just want to point out that the question of using DifferentiationInterface is pretty much orthogonal to the choice of backend. At least if you fit inside the restrictions of DI (a single active argument), you can choose between Enzyme, Zygote, or anything else you fancy. Whether FFTs work or not depends on the existence of rules for a given backend, not on the DI infrastructure; see this docs page for details.
In fact, using DI can actually make your code more resilient and allow easy switches between autodiff systems down the road. One of the main motivations for developing DI was to encourage this kind of seamless transition, e.g. from Zygote to Enzyme. Of course, it comes with a caveat: DI can induce some overhead and/or bugs that would be avoided with a backend's native API. If you run into something like that, please open an issue so we can work to fix it!
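For example (a minimal sketch, assuming your objective returns a real scalar; the toy loss here just stands in for your propagation objective), switching backends is a one-line change:

```julia
using DifferentiationInterface
using FFTW
import Zygote  # load whichever backend packages you want to compare

loss(x) = sum(abs2, fft(x))   # toy real-valued objective going through an FFT

x = randn(64)
backend = AutoZygote()        # swap for AutoEnzyme(), AutoMooncake(), ... later
g = gradient(loss, backend, x)
```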
Have you tried implementing the Enzyme FFT rule yourself? Maybe it’s not out of reach?
Alternatively, have you tried Mooncake.jl? Even if it doesn't have FFT rules either, its rule system is less complex than Enzyme's, so it may be an easier starting point.
EDIT: I don’t think Mooncake fully supports CUDA at the moment, so it doesn’t answer your full question. Sorry.
Honestly, that’s fair. If it ain’t broke, don’t fix it. It would be great to get this in Julia too, but if you need it for your research and Python has it, no one will blame you ^^
Reactant.jl supports FFTs (and their differentiation); it is used quite extensively in NeuralOperators.jl (and it also runs on CUDA, TPU, CPU, and whatever other accelerator you might want).
Is your field_p a 6-dimensional tensor? (We added support for up to 3-dimensional FFTs in Reactant, though it is easy to expand if that is what is needed.)
So Reactant uses EnzymeMLIR, in contrast to EnzymeLLVM; both are interfaced via Enzyme.jl (so no change is required on the user side except moving data via Reactant.to_rarray and doing a @compile). We have the FFT rules defined for EnzymeMLIR.
Not quite correct, but an analogy is that Reactant is similar to jax.jit, and Enzyme is similar to jax.grad. The difference here is that if you Reactant.@compile a function, it converts your Julia code into a nice tensor program that can be automatically optimized/fused/etc. and executed on whatever backend you want, including CPU, GPU, TPU, and distributed versions thereof. Essentially all programs in this tensor IR are differentiable by Enzyme(MLIR), so if it Reactant.@compiles, you're good!
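Roughly (an untested sketch; `f` here is just a stand-in for your propagation loss), the workflow looks like:

```julia
using Reactant, Enzyme

f(x) = sum(abs2, x)                           # stand-in for the propagation loss
df(x) = Enzyme.gradient(Reverse, f, x)

x = Reactant.to_rarray(randn(Float32, 256))   # move data onto Reactant arrays
df_compiled = Reactant.@compile df(x)         # trace/compile to the tensor IR
df_compiled(x)                                # runs on CPU/GPU/TPU, differentiated
```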
As for Enzyme.jl without Reactant compilation, there's no technical reason for it – I think someone just needs to write the rule and/or use Enzyme.import_rrule to import the ChainRules rule, as described in the issue you linked above. Maybe you, @ptiede, or someone else is interested in getting that over the line?
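Purely as an illustration (I haven't tested this, and the argument types are a guess; check the docs for @import_rrule before relying on it), that import could look something like:

```julia
using Enzyme, FFTW

# Hypothetical: pull the existing AbstractFFTs ChainRules rrule into Enzyme.
# The argument types here are guesses, not a tested recipe.
Enzyme.@import_rrule(typeof(fft), AbstractArray{<:Complex})
```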
That said, I'd recommend just using Reactant.@compile with Enzyme on the inside. It'll give you all the things you say are missing, and much more.
Generally, using Reactant compilation makes things more likely to be efficiently differentiated by Enzyme. It removes type instabilities, preserves structure in the IR, and does lots of other good things that make differentiation (and also just the performance of the original code) better.
I have some private code where I added Enzyme support for FFTs (non-Reactant version). I didn't consider it production-ready, but I can share it and clean it up, and we could make a PR.
bfft is the correct adjoint operator (it is literally the adjoint, the conjugate-transpose, of fft). If you used inv(fft_plan) or ifft then it would introduce an additional 1/n scale factor that you don’t want.
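A quick numerical check of that statement (illustrative snippet, not from the original post):

```julia
using FFTW, LinearAlgebra

x = randn(ComplexF64, 8)
y = randn(ComplexF64, 8)

# bfft is the adjoint of fft: ⟨fft(x), y⟩ == ⟨x, bfft(y)⟩, with no scale factor
dot(fft(x), y) ≈ dot(x, bfft(y))   # true

# ifft carries the extra 1/n factor: ifft(y) == bfft(y) / n
ifft(y) ≈ bfft(y) ./ length(y)     # true
```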