I tried to extend the model from the DiffEqFlux page to a simple one involving an FFT, but I can’t get it working with any of the sensealg options. Here’s what I’ve got:
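```julia
using FFTW
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

nlon = 64
nlat = 64
# Real spectral multiplier, sized to match the rfft output: (div(nlat,2)+1) × nlon
K = (1:(div(nlat,2)+1)) .* (1:nlon)' / (nlat*nlon)

# Out-of-place RHS: local cubic term plus a nonlocal term applied in Fourier space
function f(x, p, t)
    k = p[1]
    lc = k * irfft(K .* rfft(x), nlat)
    @. x - x^3/3 + lc
end

p = [1.5]
u0 = randn(nlat, nlon)
prob = ODEProblem(f, u0, (0.0, 10.0), p)

# Scalar loss over the saved trajectory, differentiated w.r.t. u0 and p
loss(u0, p) = sum(solve(prob, Tsit5(), u0=u0, p=p, saveat=0.1))
du01, dp1 = Zygote.gradient(loss, u0, p)
```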
This results in a type error, because FFTW doesn’t support TrackedReals. Shouldn’t the AD system know what the gradient of an FFT is, rather than trying to differentiate through FFTW itself? This just works in the autograd package I use in Python.
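A quick way to test the question above in isolation is to ask Zygote for a gradient through the same rfft/irfft term with no ODE solver involved. This is a minimal sketch, assuming a Zygote/AbstractFFTs combination that provides adjoint rules for rfft and irfft (recent AbstractFFTs releases define ChainRules rules for them). The test function g and the smaller grid are hypothetical, chosen only for the check; dims are passed explicitly, which computes the same transform as the one-argument forms used in the post.

```julia
using FFTW, Zygote

# Smaller grid than in the post, only to keep the check quick
nlat, nlon = 8, 8
K = (1:(div(nlat,2)+1)) .* (1:nlon)' / (nlat*nlon)

# Hypothetical scalar test function: squared norm of the spectral term from f.
# Explicit dims (1:2) give the same transforms as rfft(x) and irfft(y, nlat).
g(x) = sum(abs2, irfft(K .* rfft(x, 1:2), nlat, 1:2))

x = randn(nlat, nlon)
dx, = Zygote.gradient(g, x)   # only Zygote is called here; no Tracker or ReverseDiff
@show size(dx)
```

If this gradient comes back without a type error, the FFT on its own is not what drags TrackedReals in, and the error more likely comes from whichever AD backend the ODE adjoint uses internally for the RHS.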
Something funky is happening, because you shouldn’t be seeing any TrackedReals.
Those mean that either Tracker.jl or ReverseDiff.jl is being used,
whereas your code says you are using Zygote.
Can you post the full error message and stacktrace?
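Given the reasoning above (TrackedReals imply Tracker.jl or ReverseDiff.jl), one thing worth trying is to pick the sensealg explicitly and request Zygote-computed vector-Jacobian products for the RHS, so the adjoint pass does not route f through Tracker or ReverseDiff. This is a sketch, assuming a DiffEqSensitivity version that provides InterpolatingAdjoint and ZygoteVJP (newer releases ship the same functionality under the name SciMLSensitivity). Whether it then runs end to end also depends on Zygote having adjoint rules for rfft/irfft. The grid and time span are shrunk only to keep the check quick.

```julia
using FFTW
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

# Smaller grid and shorter time span than in the post, only to keep the check fast
nlon = 16
nlat = 16
K = (1:(div(nlat,2)+1)) .* (1:nlon)' / (nlat*nlon)

function f(x, p, t)
    k = p[1]
    # Explicit dims: same transforms as rfft(x) and irfft(y, nlat)
    lc = k * irfft(K .* rfft(x, 1:2), nlat, 1:2)
    @. x - x^3/3 + lc
end

p = [1.5]
u0 = randn(nlat, nlon)
prob = ODEProblem(f, u0, (0.0, 1.0), p)

# Adjoint whose vector-Jacobian products are computed with Zygote rather than
# Tracker/ReverseDiff (assumes ZygoteVJP exists in the installed version)
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
loss(u0, p) = sum(solve(prob, Tsit5(), u0=u0, p=p, saveat=0.1, sensealg=sensealg))
du01, dp1 = Zygote.gradient(loss, u0, p)
```

If the type error persists with this setting, the full stacktrace should show which package is still producing TrackedReals.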
Thanks for the replies. I initially tried an in-place form, as in the DiffEqFlux example, but then thought it might work better out of place. I put the example, traceback, and pkg status in this gist.