Nested gradient computations with a differential equation solver

The following seems to work for me:

module Minimal

using Zygote
using ForwardDiff
using DifferentialEquations
using DiffEqSensitivity

# Observation grid. The ODE time span and the solver's save points are both derived
# from it below, so the solution and the data always have the same length.
const t = 1:0.1:10
# Noise-free target signal sampled on `t`.
const gt = sin.(t)
# Noisy observations of `gt` (noise scale 0.2; unseeded, so they differ on every load).
const x0 = gt + randn(size(gt)...) * 0.2
# Point at which the outer derivative with respect to `η` is evaluated.
const η0 = 0.01

"""
    dsin(du, u, p, t)

In-place right-hand side of the scalar ODE `du/dt = cos(t)`.

Writes `cos(t)` into `du[1]`; `u` and `p` are not used.
"""
function dsin(du, u, p, t)
    du[1] = cos(t)
    return nothing
end

"""
    innergrad()

Zygote gradient, evaluated at `x = x0`, of the squared error between `x` and the ODE
solution that starts from `x[1]` and is saved at every point of `t`.

Returns a vector with the same length as `x0`.
"""
function innergrad()
    g = gradient(x0) do x
        # Take the span from the observation grid instead of the integer literal
        # `(1, 10)`. It is then floating point like the state and cannot drift from `t`.
        prob = ODEProblem(dsin, [x[1]], (first(t), last(t)))
        sol = DifferentialEquations.solve(prob, RK4(); saveat = collect(t))

        # First (only) state component at each save point.
        solx = [xx[1] for xx in sol.u]

        return sum((solx - x) .^ 2)
    end
    return g[1]
end

"""
    outergrad()

ForwardDiff derivative with respect to `η`, evaluated at `η0`, of the squared error
between `gt` and the gradient step `x0 .- η * g`, where `g = innergrad()`.
"""
function outergrad()
    return ForwardDiff.derivative(η0) do η
        # NOTE(review): `g` does not depend on `η`, so no dual numbers reach Zygote here;
        # the inner gradient only enters the outer derivative as a constant.
        g = innergrad()
        return sum((gt - (x0 .- η * g)) .^ 2)
    end
end

end

# Inner gradient on its own (Zygote through the ODE solve); marked as working.
Minimal.innergrad() # OK
# Nested case: a ForwardDiff derivative whose closure calls `innergrad`; marked as failing.
Minimal.outergrad() # KO