I tried to extend the model from the DiffEqFlux docs page to a simple one involving an FFT, but I can’t get it working with any of the `sensealg` options. Here’s what I’ve got:

```julia
using FFTW
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

nlon = 64
nlat = 64
# Spectral multiplier with the same shape as rfft(x): (div(nlat, 2) + 1) × nlon
K = (1:(div(nlat,2)+1)) .* (1:nlon)' / (nlat*nlon)

function f(x, p, t)
    k = p[1]
    # Nonlocal coupling term, applied in Fourier space
    lc = k * irfft(K .* rfft(x), nlat)
    @. x - x^3/3 + lc
end

p = [1.5]
u0 = randn(nlat, nlon)
prob = ODEProblem(f, u0, (0.0, 10.0), p)
# Scalar loss: sum over all saved states
loss(u0, p) = sum(solve(prob, Tsit5(), u0=u0, p=p, saveat=0.1))
du01, dp1 = Zygote.gradient(loss, u0, p)
```

This results in a type error because FFTW doesn’t support `TrackedReal`s. Shouldn’t the AD system know what the gradient of an FFT is, rather than trying to differentiate through FFTW itself? This just works in the autograd package I use in Python.
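For context on the question: the coupling term is linear in `x`, so its pullback is just its adjoint, which is again an FFT expression. Below is a minimal sketch that checks this with a dot-product test. It assumes full complex `fft`/`ifft` instead of `rfft`/`irfft` (the real-to-complex adjoint needs extra care with the halved frequency axis), and `coupling` / `coupling_adjoint` are hypothetical names used only for illustration.

```julia
using FFTW, LinearAlgebra

# Hypothetical helper: the spectral coupling term, written with full
# complex FFTs so the adjoint is easy to state.
coupling(x, K) = real(ifft(K .* fft(x)))

# For real x and y, the transpose of x -> real(ifft(K .* fft(x)))
# is y -> real(ifft(conj(K) .* fft(y))).
coupling_adjoint(y, K) = real(ifft(conj(K) .* fft(y)))

n = 16
K = randn(ComplexF64, n, n)   # arbitrary complex multiplier
x = randn(n, n)
y = randn(n, n)

# Dot-product test: <y, L x> should equal <L' y, x> up to rounding error.
lhs = dot(y, coupling(x, K))
rhs = dot(coupling_adjoint(y, K), x)
println(isapprox(lhs, rhs; atol = 1e-8))
```

This is the kind of rule an AD system needs to be told about: it cannot derive it by tracing into FFTW's C library, which is why FFT support comes from explicit rule definitions.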

Something funky is happening, because you shouldn’t be seeing any `TrackedReal`s.
Those mean that either Tracker.jl or ReverseDiff.jl is being used,
whereas your code says you are using Zygote.
Can you post the full error message and stack trace?
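One possible explanation, stated as an assumption about the DiffEqSensitivity version in use: the adjoint methods compute vector-Jacobian products of `f` with their own AD backend, chosen through the `autojacvec` option, and that backend can be ReverseDiff or Tracker even when the outer call is `Zygote.gradient`. A sketch that pins it to Zygote, reusing the model from the question:

```julia
using FFTW
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

nlon, nlat = 64, 64
K = (1:(div(nlat, 2) + 1)) .* (1:nlon)' / (nlat * nlon)

function f(x, p, t)
    k = p[1]
    lc = k * irfft(K .* rfft(x), nlat)
    @. x - x^3 / 3 + lc
end

p = [1.5]
u0 = randn(nlat, nlon)
prob = ODEProblem(f, u0, (0.0, 10.0), p)

# Ask the adjoint to use a Zygote-based vector-Jacobian product instead of
# letting it pick one (a ReverseDiff/Tracker choice would produce TrackedReals).
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())
loss(u0, p) = sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1,
                        sensealg = sensealg))
du0, dp = Zygote.gradient(loss, u0, p)
```

Whether this succeeds still depends on Zygote having rules for `rfft`/`irfft`; the point is only to keep every AD call on one backend.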

@devmotion has a PR up to add the rrule and frule definitions to AbstractFFTs.

Once that lands, any AD that supports ChainRules (Zygote, Yota, Diffractor, Nabla) should be able to differentiate `irfft` just fine.
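To illustrate the mechanism, here is a sketch of a ChainRulesCore `rrule` for a hypothetical wrapper `spectral_coupling`. It is not the code from the PR; it assumes full complex FFTs and treats `K` as a constant.

```julia
using FFTW, ChainRulesCore

# Hypothetical wrapper around the coupling term (name is illustrative).
spectral_coupling(x::AbstractMatrix{<:Real}, K) = real(ifft(K .* fft(x)))

function ChainRulesCore.rrule(::typeof(spectral_coupling), x::AbstractMatrix{<:Real}, K)
    y = spectral_coupling(x, K)
    function spectral_coupling_pullback(ybar)
        # The map is linear in x, so the pullback applies its adjoint.
        xbar = real(ifft(conj(K) .* fft(unthunk(ybar))))
        # K is treated as a constant here, so it gets no tangent.
        return NoTangent(), xbar, NoTangent()
    end
    return y, spectral_coupling_pullback
end
```

With a rule like this in scope, ChainRules-aware ADs such as Zygote use it instead of tracing into FFTW. Differentiating with respect to `K` as well would require returning a real cotangent for it instead of `NoTangent()`.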

Thanks for the replies. I initially tried an in-place form, as in the DiffEqFlux example, but then I thought it might work better out of place. I put the example, the traceback and the `pkg status` output in this gist.

Thanks for mentioning the PR, though it refers to existing definitions in Zygote, which seem to suggest Zygote should already know how to handle FFTs.
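One way to separate the two questions (whether Zygote handles `rfft`/`irfft`, and what the solver's adjoint uses internally) is to differentiate the coupling term on its own, outside `solve`. A sketch, assuming the same `K` as in the original example; `coupling_loss` is a hypothetical name:

```julia
using FFTW, Zygote

nlon, nlat = 64, 64
K = (1:(div(nlat, 2) + 1)) .* (1:nlon)' / (nlat * nlon)
x = randn(nlat, nlon)

# Differentiate only the FFT coupling term, with no ODE solver involved.
coupling_loss(x) = sum(irfft(K .* rfft(x), nlat))
g = Zygote.gradient(coupling_loss, x)[1]
```

Because the sum of an inverse transform only sees the zero-frequency coefficient, `coupling_loss(x)` equals `K[1, 1] * sum(x)`, so a correct gradient is a constant array filled with `K[1, 1]`. If this works, the `TrackedReal`s most likely come from inside the sensitivity algorithm rather than from Zygote's FFT rules.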