```julia
using DiffEqSensitivity, OrdinaryDiffEq, Zygote, StochasticDiffEq

# similar to documentation, but out-of-place
# drift: Lotka-Volterra
function fiip(u,p,t)
  du1 = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du2 = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
  return [du1,du2]
end
# diagonal noise with a small constant amplitude
function fiip_s(u,p,t)
  du1 = dx = 0.001
  du2 = dy = 0.001
  return [du1,du2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
probSDE = SDEProblem(fiip,fiip_s,u0,(0.0,1.0),p) # changed terminal time to 1.0
solSDE = solve(probSDE,SOSRI())

# correct answer
du01SDE,dp1SDE = Zygote.gradient((u0,p)->sum(solve(probSDE,SOSRI(),u0=u0,p=p,saveat=0.1,sensealg=TrackerAdjoint())),u0,p)

# gradient with respect to u0 = nothing, which is wrong
du01SDE,dp1SDE = Zygote.gradient((u0,p)->sum(solve(probSDE,SOSRI(),u0=u0,p=p,saveat=0.1,sensealg=ForwardDiffSensitivity())),u0,p)

# error but docs say it is supported?
du01SDE,dp1SDE = Zygote.gradient((u0,p)->sum(solve(probSDE,SOSRI(),u0=u0,p=p,saveat=0.1,sensealg=ReverseDiffAdjoint())),u0,p)
```
FWIW, they all give the same answer. The forward-mode one just doesn’t return the `u0` part of the derivative (track the issue here: https://github.com/SciML/DiffEqSensitivity.jl/issues/156). Returning `nothing` there is correct for Zygote: if you try to use that gradient value you’ll rightly get an error, because right now our forward-mode overload doesn’t compute it.
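To see Zygote's `nothing` convention in isolation, here is a minimal sketch that uses only Zygote and a made-up toy function (`loss` is hypothetical, not from the thread). It does not reproduce the SDE case: here `y` simply never affects the output, whereas with `ForwardDiffSensitivity` the `u0` derivative exists but is not computed. In both situations the slot comes back as `nothing`, and using it like an array raises a `MethodError` instead of quietly acting like zero.

```julia
using Zygote

# Hypothetical toy loss: depends on x only, y is ignored.
loss(x, y) = sum(abs2, x)

gx, gy = Zygote.gradient(loss, [1.0, 2.0], [3.0])
# gx is 2 .* x, i.e. [2.0, 4.0]; gy is `nothing` because y does not affect the loss.

# Treating the `nothing` slot as a number/array fails loudly
# rather than silently behaving like a zero gradient.
errored_as_expected = try
    gy + 1.0
    false
catch err
    err isa MethodError
end
println("gy = ", gy, "; using it raised a MethodError: ", errored_as_expected)
```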
`ReverseDiffAdjoint` just needs DistributionsAD and a new tagged release of ReverseDiff. I put those in motion and by tomorrow morning it should be good again. Sorry about that.
Thank you! Very helpful.
I had interpreted `nothing` as a “generalized zero”, as explained in this issue: https://github.com/FluxML/Zygote.jl/issues/329, so I thought it was “wrong”. But now I see that `nothing` can also mean “not yet implemented”.
Yeah, that issue is about how `Zero()` could (and should) be different. But Zygote just has `nothing`, so you do have to be a bit careful.
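One way to "be a bit careful", sketched under the assumption that you would rather fail loudly than silently assume a zero: a hypothetical wrapper `strict_gradient` (not part of Zygote's API) that rejects any `nothing` component, plus an explicit opt-in `nothing_to_zero` (also hypothetical) for arguments you *know* are structural zeros. Neither helper can tell the two meanings of `nothing` apart; they only force you to make that decision yourself.

```julia
using Zygote

"""
    strict_gradient(f, args...)

Hypothetical helper: call `Zygote.gradient` and throw if any component
is `nothing`, so a missing derivative is never silently treated as zero.
"""
function strict_gradient(f, args...)
    grads = Zygote.gradient(f, args...)
    for (i, g) in enumerate(grads)
        if g === nothing
            throw(ArgumentError("gradient for argument $i is `nothing`: " *
                                "either a structural zero or not computed"))
        end
    end
    return grads
end

# Hypothetical opt-in conversion: use only where `nothing` is known
# to mean "this argument does not affect the output".
nothing_to_zero(g, x) = g === nothing ? zero(x) : g
```

For example, `strict_gradient((x, y) -> sum(x .* y), [1.0, 2.0], [3.0, 4.0])` returns both gradients, while `strict_gradient((x, y) -> sum(x), [1.0], [2.0])` throws, because the second slot comes back as `nothing`.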