Hi, I am exploring how to combine Julia's differential equation solvers with its automatic differentiation capabilities, and I am stuck trying to backpropagate twice through an ODE solver (taking a gradient of a function that itself computes a gradient through `solve`).
Here is some code to reproduce the problem:
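```julia
module Minimal

using Zygote
using DifferentialEquations
using DiffEqSensitivity

# Noisy observations of sin(t)
t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

# In-place ODE right-hand side: du/dt = cos(t)
function dsin(du, u, p, t)
    du[1] = cos(t)
end

# Gradient (w.r.t. x) of the misfit between the ODE solution started at x[1] and x
function innergrad()
    g = gradient(x0) do x
        prob = ODEProblem(dsin, [x[1]], (1, 10))
        sol = DifferentialEquations.solve(
            prob, RK4(), saveat=collect(1:0.1:10),
        )
        solx = [xx[1] for xx in sol.u]
        sum((solx - x) .^ 2)
    end
    g[1]
end

# Gradient w.r.t. the step size η of the loss after one gradient step on x0;
# Zygote has to differentiate through innergrad(), i.e. through the ODE solve a second time
function outergrad()
    gradient(η0) do η
        g = innergrad()
        sum((gt - (x0 - η*g)) .^ 2)
    end
end

end

Minimal.innergrad()  # OK
Minimal.outergrad()  # fails
```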
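As a sanity check that does not rely on Zygote or DiffEqSensitivity, both gradients in this example have closed forms. The ODE du/dt = cos(t) with u(1) = x[1] has the exact solution u(t) = x[1] + sin(t) - sin(1), and since `innergrad()` does not depend on η, the outer derivative is simply 2 Σ (gt_i - x0_i + η g_i) g_i. Below is a minimal Base-Julia sketch that checks both formulas against central finite differences. It assumes the exact solution as a stand-in for the RK4 solve, and the helper names (`ode_exact`, `inner_loss`, `inner_grad_exact`, `outer_loss`, `outer_grad_exact`, `unitvec`) are hypothetical, not part of any package.

```julia
# Sanity check for the example above, using only Base Julia.
# Assumption: the exact solution of du/dt = cos(t), u(1) = x[1] replaces the RK4 solve.
# All function names below are hypothetical helpers, not package APIs.

t  = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(length(gt)) * 0.2
η0 = 0.01

# Exact ODE solution sampled at t: u(t) = x1 + sin(t) - sin(1)
ode_exact(x1) = x1 .+ sin.(t) .- sin(1)

# Inner loss from innergrad(), with the solver replaced by ode_exact
inner_loss(x) = sum((ode_exact(x[1]) - x) .^ 2)

# Closed-form gradient of inner_loss with respect to x
function inner_grad_exact(x)
    r = ode_exact(x[1]) - x       # residuals r_i = u(t_i) - x_i
    g = -2 .* r                   # dL/dx_j = -2 r_j for j >= 2
    g[1] = 2 * sum(r) - 2 * r[1]  # x[1] shifts every u(t_i) but cancels in r_1
    return g
end

# Outer loss from outergrad(); g is a constant here because it does not depend on η
outer_loss(η, g) = sum((gt - (x0 - η * g)) .^ 2)
outer_grad_exact(η, g) = 2 * sum((gt - x0 + η * g) .* g)

# Central finite differences for comparison
h = 1e-6
unitvec(j) = [i == j ? 1.0 : 0.0 for i in eachindex(x0)]
g = inner_grad_exact(x0)
fd_inner = [(inner_loss(x0 + h * unitvec(j)) - inner_loss(x0 - h * unitvec(j))) / (2h)
            for j in eachindex(x0)]
fd_outer = (outer_loss(η0 + h, g) - outer_loss(η0 - h, g)) / (2h)

println(maximum(abs.(g .- fd_inner)))              # max inner-gradient discrepancy
println(abs(outer_grad_exact(η0, g) - fd_outer))   # outer-gradient discrepancy
```

Both printed discrepancies should be at rounding level. Because the exact solution is not the RK4 output, `Minimal.innergrad()` can only be expected to match `inner_grad_exact(Minimal.x0)` approximately.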
Here are my installed packages (`Pkg.status()`):

```
[a077e3f3] DiffEqProblemLibrary v4.15.0
[41bf760c] DiffEqSensitivity v6.71.0
[0c46a032] DifferentialEquations v7.1.0
[61744808] DynamicalSystems v2.1.8
[587475ba] Flux v0.12.9
[7073ff75] IJulia v1.23.2
[a98d9a8b] Interpolations v0.13.5
[2b0e0bc5] LanguageServer v4.2.0
[30363a11] NetCDF v0.11.4
[1dea7af3] OrdinaryDiffEq v6.7.1
[91a5bcdd] Plots v1.27.3
[438e738f] PyCall v1.93.1
[e88e6eb3] Zygote v0.6.37
```

and my Julia version is 1.7.2.
Any help is greatly appreciated!