Hi!
For the third time, I am trying to implement some wave optics algorithms in Julia, but each time I start, I struggle a bit with the current state of automatic differentiation.
In principle, my functions look like this:

```julia
using FFTW

# assumed alias (not defined in the post): scalars or arrays are both accepted
const NumberOrArray = Union{Number, AbstractArray}

function propagate(field::AbstractArray{T, 6}, distance::NumberOrArray,
                   wavelengths::NumberOrArray, pixel_size) where {T}
    H = # some kernel calculated with distances and wavelengths
    field_p = # some processing on the field with wavelengths and distance
    field_prop = ifft(fft(field_p) .* H) # roughly
    return field_prop
end
```
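For the FFT line on its own, the adjoint has a closed form because `ifft(fft(x) .* H)` is linear in both `x` and `H`. A minimal sketch of a custom ChainRules rule, assuming a hypothetical helper `fftconv` that transforms over all dimensions and an `H` of the same size as `x` (a real rule for the 6-D case would also have to sum the cotangent of `H` over its broadcast singleton dimensions):

```julia
using FFTW, ChainRulesCore

# Hypothetical helper (not from the post): FFT-based filtering of `x` with a
# transfer function `H` of the same size, transforming over all dimensions.
fftconv(x, H) = ifft(fft(x) .* H)

function ChainRulesCore.rrule(::typeof(fftconv), x, H)
    X = fft(x)
    y = ifft(X .* H)
    N = length(x)                          # size of the unnormalized DFT
    project_x, project_H = ProjectTo(x), ProjectTo(H)
    function fftconv_pullback(ȳ)
        Ȳ = fft(unthunk(ȳ))
        x̄ = ifft(Ȳ .* conj.(H))            # adjoint of x ↦ ifft(fft(x) .* H)
        H̄ = conj.(X) .* Ȳ ./ N             # adjoint of H ↦ ifft(X .* H)
        # ProjectTo drops the imaginary part if x or H is real
        return NoTangent(), project_x(x̄), project_H(H̄)
    end
    return y, fftconv_pullback
end
```

The derivation uses that the unnormalized DFT matrix F satisfies F⁻¹ = F'/N. A dot-product test such as `real(dot(ȳ, fftconv(δx, H))) ≈ real(dot(x̄, δx))` is a cheap way to check a hand-written rule like this.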
The adjoint is a bit nasty to write manually, since `distance` or `wavelengths` can be either scalars or vectors. The output will be 6-dimensional `(x, y, z, wavelengths, polarization, batch)`, but some dimensions are singleton depending on the types of `distance` and `wavelengths`.
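One way to let broadcasting handle both the scalar and the vector case is to reshape each parameter onto its own dimension. A minimal sketch, assuming `distance` lives on dimension 3 and `wavelengths` on dimension 4 of the layout above; `along_dim` and `toy_phase` are hypothetical names, and the phase factor exp(i 2π z / λ) is only a toy stand-in for the real kernel:

```julia
# Hypothetical helper: place a scalar-or-vector parameter along dimension `dim`
# of the assumed 6-D layout (x, y, z, wavelength, polarization, batch).
along_dim(p::Number, dim::Int) = p                        # scalars broadcast as-is
function along_dim(p::AbstractVector, dim::Int)
    shape = ntuple(i -> i == dim ? length(p) : 1, 6)      # singleton everywhere else
    return reshape(p, shape)
end

# Toy phase factor exp(i 2π z / λ), assuming distance on dim 3, wavelength on dim 4
toy_phase(distance, wavelengths) =
    cis.(2π .* along_dim(distance, 3) ./ along_dim(wavelengths, 4))

size(toy_phase(0.5, [1.0, 2.0, 3.0]))          # (1, 1, 1, 3, 1, 1)
size(toy_phase([0.1, 0.2], [1.0, 2.0, 3.0]))   # (1, 1, 2, 3, 1, 1)
```

Broadcasting such a factor against the 6-D field keeps the output 6-D, and the singleton dimensions then follow from the argument types without separate code paths.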
Anyway, it seems like Enzyme.jl does not support FFTs at the moment (the latest report is this issue, an older one is this). So in my current code I use a mix of Zygote and custom-written ChainRules.jl rules.
I’d love to use DifferentiationInterface.jl, but if FFTs do not work, I’m going to have a hard time.
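With DifferentiationInterface.jl, the AD backend is just an argument to `gradient(f, backend, x)`, so code written against it can keep using Zygote today and try another backend (e.g. `AutoEnzyme()`) later by changing one object. A minimal sketch with a toy FFT-based loss; `H` and `loss` are hypothetical and not the post's propagation kernel:

```julia
using DifferentiationInterface   # re-exports ADTypes backends such as AutoZygote
using FFTW
import Zygote

# Toy unit-modulus transfer function and a scalar loss built on an FFT filter
H = cis.(range(0, 2π; length = 8))
loss(x) = sum(abs2, ifft(fft(x) .* H))

x = randn(8)
g = gradient(loss, AutoZygote(), x)   # swap the backend object to try another AD
```

Because `|H| = 1`, this filter preserves the norm, so the gradient should come out as `2x`, which makes a quick sanity check when comparing backends.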
Does anyone have an idea what I should do? Using Zygote.jl does not feel quite future-proof, since everyone else seems to be switching to Enzyme.jl.
Honestly, I am seriously considering implementing it in another language, because this pain has been ongoing for years (I first hit these general issues in 2020). I know there has been progress on this front and things have improved, but they have not yet reached my feature requirements (CUDA + FFT + complex numbers). And writing an Enzyme rule seems very hard to me. I would love to help out a bit, but right now I want to focus on my actual research problems.
I see similar packages implemented in PyTorch and JAX, and it works well for them, so I feel a bit hopeless and cannot recommend Julia to them right now.
Best,
Felix