using DiffEqSensitivity, OrdinaryDiffEq, Zygote, StochasticDiffEq
# similar to documentation, but out-of-place
function fiip(u,p,t)
    du1 = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du2 = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
    return [du1,du2]
end
function fiip_s(u,p,t)
    du1 = dx = 0.001
    du2 = dy = 0.001
    return [du1,du2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
probSDE = SDEProblem(fiip,fiip_s,u0,(0.0,1.0),p) # changed terminal time to 1.0
solSDE = solve(probSDE,SOSRI())
# correct answer
du01SDE,dp1SDE = Zygote.gradient((u0,p)->sum(solve(probSDE,SOSRI(),u0=u0,p=p,saveat=0.1,sensealg=TrackerAdjoint())),u0,p)
# returns `nothing` for the gradient with respect to u0, which looks wrong
du01SDE,dp1SDE = Zygote.gradient((u0,p)->sum(solve(probSDE,SOSRI(),u0=u0,p=p,saveat=0.1,sensealg=ForwardDiffSensitivity())),u0,p)
# throws an error, but the docs say it is supported?
du01SDE,dp1SDE = Zygote.gradient((u0,p)->sum(solve(probSDE,SOSRI(),u0=u0,p=p,saveat=0.1,sensealg=ReverseDiffAdjoint())),u0,p)
FWIW, they all give the same answer. The forward-mode one just doesn’t return the `u0` part of the derivative (you can track the issue here: https://github.com/SciML/DiffEqSensitivity.jl/issues/156). Returning `nothing` is valid as far as Zygote is concerned, and if you try to use that gradient value you’ll properly get an error, because right now our forward-mode overload doesn’t compute it.
ReverseDiffAdjoint just needs DistributionsAD and a ReverseDiff tag. I put those in motion and by tomorrow morning it should be good again. Sorry about that.
Thank you! Very helpful.
I had interpreted `nothing` as a “generalized zero”, as explained in this issue: https://github.com/FluxML/Zygote.jl/issues/329, so I thought it was “wrong”, but now I see `nothing` can also mean “not yet implemented”.
Thanks again!
Yeah, that issue is about how `nothing` and `Zero()` could (and should) be different. But Zygote just has `nothing`, so you do have to be a bit careful.
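For anyone who lands here later, a minimal sketch of the "generalized zero" case, independent of the SDE code above. The function `f` and the helper `zero_if_nothing` are hypothetical names made up for this illustration, not part of Zygote or DiffEqSensitivity. Here `b` is deliberately unused, so `nothing` really does stand for a zero gradient:

```julia
using Zygote

# Hypothetical helper (not a Zygote API): replace a `nothing` gradient
# with an array of zeros shaped like the corresponding input.
zero_if_nothing(g, x) = g === nothing ? zero(x) : g

# Toy function in which `b` never influences the result.
f(a, b) = sum(abs2, a)

a, b = [1.0, 2.0], [3.0, 4.0]
da, db = Zygote.gradient(f, a, b)

# `db` is `nothing` because `b` is unused; `da` is an ordinary array.
db_safe = zero_if_nothing(db, b)
```

Substituting zeros like this is only safe when you know the argument truly has no effect on the output. In the forward-mode case discussed above, `nothing` means the `u0` derivative was not computed, so filling in zeros there would silently give a wrong gradient.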