I tried to extend the model from the DiffEqFlux page to a simple one involving an FFT, but I can’t get it working with any of the sensealg options. Here’s what I’ve got:
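```julia
using DiffEqSensitivity, OrdinaryDiffEq, Zygote
using FFTW  # provides rfft / irfft

nlon = 64
nlat = 64
# Spectral weights, sized to match rfft's output: (div(nlat,2)+1) × nlon
K = (1:(div(nlat,2)+1)) .* (1:nlon)' / (nlat*nlon)

# Right-hand side of the ODE: cubic local term plus a spectral (FFT-based) term
function f(x, p, t)
    k = p[1]                             # p is a 1-element parameter vector
    lc = k * irfft(K .* rfft(x), nlat)   # filter in Fourier space, back to real space
    @. x - x^3/3 + lc
end

p = [1.5];
u0 = randn(nlat, nlon);
prob = ODEProblem(f, u0, (0.0, 10.0), p);
loss(u0,p) = sum(solve(prob,Tsit5(),u0=u0,p=p,saveat=0.1))
du01,dp1 = Zygote.gradient(loss,u0,p)
```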
which results in a type error because the FFT package doesn’t support TrackedReals. Doesn’t the AD system recognise what the gradient of an FFT should be and not try to diff through FFTW? This just works in the autograd package I use in Python.
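For reference, the rule being asked about is short because the DFT is linear: its pullback is just the adjoint map, which is an unnormalized inverse DFT. For a length-N signal:

$$
y_j = \sum_{k=0}^{N-1} e^{-2\pi i jk/N}\, x_k
\quad\Longrightarrow\quad
\bar{x}_k = \sum_{j=0}^{N-1} e^{+2\pi i jk/N}\, \bar{y}_j,
\qquad\text{i.e.}\quad \bar{x} = F^{H}\bar{y} = N\,\mathrm{ifft}(\bar{y}).
$$

The real-input transforms rfft/irfft need extra care in their rules because they store only the non-redundant half of the spectrum.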
Something funky is happening, because you shouldn’t be seeing any TrackedReals at all. Those mean that either Tracker.jl or ReverseDiff.jl is being used, whereas your code says you are using Zygote.
Can you post the full error message and stacktrace?
@devmotion has a PR up to add rrule and frule definitions to AbstractFFTs.
Once that lands, any AD that supports ChainRules (Zygote, Yota, Diffractor, Nabla) should be able to differentiate through irfft just fine.
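To make concrete what such a rule looks like, here is a minimal sketch for the complex-to-complex fft only, assuming ChainRulesCore and FFTW are installed. It is not the code from the PR; the wrapper name myfft is hypothetical and is used so the sketch does not overwrite any rules AbstractFFTs itself may define.

```julia
using ChainRulesCore, FFTW

# Hypothetical wrapper around FFTW's complex-to-complex FFT (over all dimensions).
myfft(x::AbstractArray{<:Complex}) = fft(x)

function ChainRulesCore.rrule(::typeof(myfft), x::AbstractArray{<:Complex})
    y = myfft(x)
    # The DFT is linear, so the pullback applies its adjoint F^H,
    # which is the unnormalized backward transform bfft (= N * ifft).
    function myfft_pullback(ȳ)
        return NoTangent(), bfft(unthunk(ȳ))
    end
    return y, myfft_pullback
end
```

Since Zygote picks up ChainRulesCore rrules, a call such as Zygote.gradient(x -> sum(abs2, myfft(x)), randn(ComplexF64, 8)) would go through this pullback instead of trying to trace into FFTW.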