The following seems to work for me
module Minimal

using Zygote
using ForwardDiff
using DifferentialEquations
using DiffEqSensitivity

# Time grid shared by the data, the ODE time span and the solver's saved
# points. Reusing it everywhere guarantees the saved trajectory has the same
# length as the data vectors it is compared against.
const t = 1:0.1:10
# Ground-truth signal sampled on `t`.
const gt = sin.(t)
# Noisy observations: ground truth plus Gaussian noise scaled by 0.2
# (drawn once, when the module is loaded).
const x0 = gt .+ 0.2 .* randn(length(gt))
# Step size at which the outer derivative is evaluated.
const η0 = 0.01

"""
    dsin(du, u, p, t)

In-place ODE right-hand side setting `du[1] = cos(t)`. The state `u` and
parameters `p` are not used.
"""
function dsin(du, u, p, t)
    du[1] = cos(t)
    return nothing
end

"""
    innergrad()

Return the Zygote gradient, with respect to `x` at `x = x0`, of
`sum((solx .- x) .^ 2)`, where `solx` is the solution of the `dsin` ODE
started from `x[1]` and saved on the grid `t`.
"""
function innergrad()
    # Floating-point time span taken from the same grid as the data.
    tspan = (first(t), last(t))
    g = gradient(x0) do x
        prob = ODEProblem(dsin, [x[1]], tspan)
        # Save exactly on the data grid so `solx` lines up element-wise with `x`.
        sol = DifferentialEquations.solve(prob, RK4(); saveat=collect(t))
        solx = [u[1] for u in sol.u]
        return sum((solx .- x) .^ 2)
    end
    return g[1]
end

"""
    outergrad()

Return the ForwardDiff derivative, with respect to the step size `η` at `η0`,
of `sum((gt .- (x0 .- η .* g)) .^ 2)`, where `g = innergrad()`.
"""
function outergrad()
    # `g` does not depend on `η`, so compute it once before entering the
    # ForwardDiff closure; this keeps the Zygote/ODE computation out of the
    # dual-number pass instead of nesting it inside.
    g = innergrad()
    return ForwardDiff.derivative(η0) do η
        return sum((gt .- (x0 .- η .* g)) .^ 2)
    end
end

end
# Zygote gradient of the ODE-fit loss with respect to the data vector
# (marked OK by the original author).
Minimal.innergrad() # OK
# ForwardDiff derivative of the one-step-updated loss with respect to η;
# the original report marks this call as failing ("KO") — re-check.
Minimal.outergrad() # KO