Gradients of FFT

I tried to extend the DiffEqFlux page's model to a simple one involving an FFT, but I can't get it working with any of the `sensealg` options. Here's what I've got:

```julia
using FFTW
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

nlon = 64
nlat = 64
# Spectral multiplier with the same shape as rfft(x) for an nlat×nlon x: (nlat÷2 + 1) × nlon
K = (1:(div(nlat,2)+1)) .* (1:nlon)' / (nlat*nlon)

# Out-of-place RHS: cubic local term plus an FFT-based coupling term
function f(x,p,t)
    k = p[1]
    lc = k * irfft(K .* rfft(x), nlat)
    @. x - x^3/3 + lc
end

p = [1.5];
u0 = randn(nlat, nlon);
prob = ODEProblem(f, u0, (0.0, 10.0), p);
loss(u0,p) = sum(solve(prob,Tsit5(),u0=u0,p=p,saveat=0.1))
du01,dp1 = Zygote.gradient(loss,u0,p)
```

This results in a type error because FFTW doesn't support `TrackedReal` inputs. Shouldn't the AD system know what the gradient of an FFT is, rather than trying to differentiate through FFTW? This just works in the autograd package I use in Python.
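For context on why a dedicated FFT rule is possible at all: the FFT is linear, so its vector-Jacobian product is just multiplication by the adjoint of the DFT matrix, and for the unnormalized forward transform that adjoint is the unnormalized backward transform `bfft`. A small numerical check of that identity, assuming only FFTW.jl and the standard library (`n` is an arbitrary size chosen for illustration):

```julia
using FFTW, LinearAlgebra

n = 8
# Build the dense DFT matrix F by transforming the columns of the identity:
# column j of F is fft(e_j).
F = fft(Matrix{ComplexF64}(I, n, n), 1)

v = randn(ComplexF64, n)
# fft(v) is F * v ...
@assert F * v ≈ fft(v)
# ... and the adjoint F' (what reverse mode applies) is the unnormalized inverse, bfft.
@assert F' * v ≈ bfft(v)
```

Since `ifft(v) == bfft(v) / n`, the same identity can equivalently be written with `ifft` scaled by `n`.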

Something funky is happening, because you shouldn't be seeing any `TrackedReal`s. Those mean that either Tracker.jl or ReverseDiff.jl is being used, whereas your code says you are using Zygote. Can you post the full error message and stacktrace?

@devmotion has a PR up to add the rrule and frule definitions to AbstractFFTs.

Once that lands, any AD system that supports ChainRules (Zygote, Yota, Diffractor, Nabla) should be able to differentiate `irfft` just fine.

Until then, you can try using their branch.
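To illustrate what such a rule looks like, here is a minimal `rrule` sketch using ChainRulesCore. This is not the code from that PR: `myfft` is a hypothetical wrapper introduced so the example does not override any rules FFTW, AbstractFFTs, or Zygote already ship, and it only handles vectors. Zygote picks up the rule automatically.

```julia
using FFTW, ChainRulesCore, Zygote

# Hypothetical wrapper: defining the rule on our own function avoids
# overriding existing FFT rules in FFTW/AbstractFFTs/Zygote.
myfft(x::AbstractVector) = fft(x)

function ChainRulesCore.rrule(::typeof(myfft), x::AbstractVector)
    y = myfft(x)
    project_x = ProjectTo(x)  # maps the complex cotangent back to x's type (real part for real x)
    function myfft_pullback(ȳ)
        # Reverse mode applies the adjoint of the DFT matrix, i.e. the unnormalized bfft.
        x̄ = bfft(unthunk(ȳ))
        return NoTangent(), project_x(x̄)
    end
    return y, myfft_pullback
end

# By Parseval, sum(abs2, fft(x)) == length(x) * sum(abs2, x), so the gradient is 2n * x.
loss(x) = sum(abs2, myfft(x))

x = randn(16)
g = Zygote.gradient(loss, x)[1]
```

The Parseval identity gives a closed-form gradient, which makes this a convenient sanity check for the rule.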

Yeah… for reference, the current adjoint heuristic is in `concrete_solve.jl` of SciML/DiffEqSensitivity.jl at v6.60.2 on GitHub. Since your `f` is not in-place, it should be using Zygote for VJPs, and no tracked values should appear. So I am confused as well and would like to see that stacktrace.
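One way to narrow this down is to take the same kind of VJP yourself with Zygote, outside the solver. The sketch below reuses the out-of-place `f` from the question (with the missing `end` added); `λ` is an arbitrary cotangent chosen for illustration. If this runs, it suggests Zygote itself can differentiate the RHS, and a `TrackedReal` error would then point at a different VJP backend being chosen inside the sensitivity algorithm.

```julia
using FFTW, Zygote

nlon, nlat = 64, 64
K = (1:(div(nlat, 2) + 1)) .* (1:nlon)' / (nlat * nlon)

# Out-of-place RHS from the question above.
function f(x, p, t)
    k = p[1]
    lc = k * irfft(K .* rfft(x), nlat)
    return @. x - x^3 / 3 + lc
end

p = [1.5]
u0 = randn(nlat, nlon)

# One VJP of the RHS with respect to both u and p, the kind of product
# an adjoint method requests at each step of the backward pass.
du, back = Zygote.pullback((u, q) -> f(u, q, 0.0), u0, p)
λ = randn(size(du)...)   # arbitrary cotangent for the output
ū, p̄ = back(λ)
```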

Thanks for the replies. I initially tried an in-place form, as in the DiffEqFlux example, but then thought it might work better out-of-place. I put the example, stacktrace, and package status in this gist.

Thanks for mentioning the PR, though it refers to existing definitions in Zygote, which seems to suggest Zygote should already know how to handle FFTs.

After trying the different algorithms listed in the docs, `sensealg=BacksolveAdjoint()` seems to have done the trick.
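For reference, a sketch of passing the sensitivity algorithm explicitly through the `sensealg` keyword of `solve`, assuming DiffEqSensitivity v6-era APIs. The `autojacvec = ZygoteVJP()` choice is an addition not mentioned in the thread (it requests Zygote-based VJPs of the RHS), and the grid size and time span are reduced only to keep the example quick. `f` stays out-of-place, since that is the case reported to work.

```julia
using FFTW, DiffEqSensitivity, OrdinaryDiffEq, Zygote

# Smaller grid and shorter time span than the question, purely to keep the sketch fast.
nlon, nlat = 16, 16
K = (1:(div(nlat, 2) + 1)) .* (1:nlon)' / (nlat * nlon)

function f(x, p, t)
    k = p[1]
    lc = k * irfft(K .* rfft(x), nlat)
    return @. x - x^3 / 3 + lc
end

p = [1.5]
u0 = randn(nlat, nlon)
prob = ODEProblem(f, u0, (0.0, 1.0), p)

# Request the adjoint explicitly instead of relying on the default heuristic,
# and ask for Zygote-based VJPs of the (out-of-place) RHS.
sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())
loss(u0, p) = sum(solve(prob, Tsit5(); u0 = u0, p = p, saveat = 0.1, sensealg = sensealg))

du0, dp = Zygote.gradient(loss, u0, p)
```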

Edit: this only works for the out-of-place case.

On the latest versions this all works out of the box.

Yes, because scalar-tracing the FFT won't work (at least not without Enzyme support).