Nested gradient computations with a differential equation solver

[EDIT]: Code updated with a working solution.

@ChrisRackauckas, thank you for your reply. To give you some additional context: the idea behind my question is to optimize the hyperparameters of a variational data assimilation scheme (4D-Var). My code was perhaps too minimal (see below for the full code).

  • The inner differentiation is part of a gradient descent that minimizes a variational cost.
  • I then want to optimize some hyperparameters of that gradient descent (e.g. its step size η), i.e. backpropagate through the gradient descent itself.

I’m not sure whether your solution still applies; I need to read up to better understand the terms you introduced (forward-over-reverse, Hessian-free, Newton–Krylov…). I’ll try a few things and come back here with a solution or some additional questions :slight_smile:

module Tmp
# Building blocks for tuning the step size of a 4D-Var gradient-descent solver
# by differentiating through that solver.
using Zygote
using ForwardDiff
using ReverseDiff
using DifferentialEquations
using DiffEqSensitivity
using Plots

"""
    _simulate(df, x, t)

Integrate the in-place, scalar-state ODE `df(du, u, p, t)` with `RK4()`,
starting from the initial value `x` at time `first(t)`, and return the state
sampled at every point of `t` as a plain vector.

Shared by `varcost` and `varsolve` so both use the same forward model.
"""
function _simulate(df, x, t)
    # Integrate over exactly the span covered by `t`, so every `saveat` point
    # lies inside the integration interval. This used to be hard-coded to
    # (1, 10) regardless of `t`.
    tspan = (first(t), last(t))
    prob = ODEProblem(df, [x], tspan)
    sol = DifferentialEquations.solve(prob, RK4(), saveat=collect(t))
    # The state is stored as a length-1 vector; unwrap it at each saved time.
    return [u[1] for u in sol.u]
end

"""
    varcost(t, df, obs_idx, x, y)

4D-Var observation cost. Integrate `df` from the initial state `x` over the
times `t`, then return the sum of squared differences between the simulated
states at indices `obs_idx` and the observations `y`.
"""
function varcost(t, df, obs_idx, x, y)
    solx = _simulate(df, x, t)
    obs_cost = sum((solx[obs_idx] - y) .^ 2)
    return obs_cost
end

"""
    varsolve(x0, df, t, xvarcost, niter, η)

Estimate the initial state by `niter` steps of plain gradient descent with step
size `η` on `xvarcost`, starting from `x0`. Gradients come from Zygote.

Return the trajectory of `df` integrated from that estimate and sampled at `t`.
`outer_fit_rev` / `outer_fit_fwd` differentiate this function with respect to
`η`, i.e. through the gradient-descent loop.
"""
function varsolve(x0, df, t, xvarcost, niter, η)
    # No copy needed: the loop rebinds `x` and never mutates it in place.
    x = x0
    for _ in 1:niter
        g = gradient(xvarcost, x)[1]
        x = x - η * g
    end
    return _simulate(df, x, t)
end

"""
    outer_fit_rev(η0, ηvarsolve, loss, niter=1, lr=0.001)

Tune the scalar hyperparameter `η` by `niter` steps of gradient descent on
`loss(ηvarsolve(η))`, with learning rate `lr`, starting from `η0`.

Gradients are computed in reverse mode with ReverseDiff. Return the final `η`.
"""
function outer_fit_rev(η0, ηvarsolve, loss, niter=1, lr=0.001)
    # ReverseDiff.gradient needs an array input, so η is wrapped in a
    # 1-element vector and unwrapped again by `first`.
    objective = loss ∘ ηvarsolve ∘ first
    η = η0
    for _ in 1:niter
        ηvec = [η]
        gη = ReverseDiff.gradient(objective, ηvec, ReverseDiff.GradientConfig(ηvec))[1]
        η = η - lr * gη
    end
    return η
end

"""
    outer_fit_fwd(η0, ηvarsolve, loss, niter=1, lr=0.001)

Forward-mode (ForwardDiff) counterpart of `outer_fit_rev`. Run `niter` steps of
gradient descent on `loss(ηvarsolve(η))` with learning rate `lr`, starting from
`η0`, and return the final `η`. `loss` must return a scalar.
"""
function outer_fit_fwd(η0, ηvarsolve, loss, niter=1, lr=0.001)
    objective = loss ∘ ηvarsolve
    η = η0
    for _ in 1:niter
        # For a scalar-valued objective `derivative` already returns a scalar,
        # so no indexing is needed.
        gη = ForwardDiff.derivative(objective, η)
        η = η - lr * gη
    end
    return η
end
end

using Plots
using Random
using Statistics



# Synthetic twin experiment: the truth is sin(t), and every `sub_samp`-th point
# is observed with additive Gaussian noise.
t = 1:0.1:10
gt = sin.(t)
sub_samp = 3
ss_idx = 1:sub_samp:length(gt)
ss_gt = gt[ss_idx]

# Use a seeded RNG so the noisy observations, and therefore the tuned η, are
# reproducible from run to run.
obs_noise_std = 0.2
rng = MersenneTwister(1234)
obs = ss_gt + obs_noise_std * randn(rng, size(ss_gt))

# Deliberately biased first guess for the initial state.
x0 = ss_gt[1] - 0.3

"""
    dsin(du, u, p, t)

In-place right-hand side of the toy model `du/dt = cos(t)`. Starting from
`u(t₀) = sin(t₀)`, its solution is `sin(t)`. The state `u` and parameters `p`
are unused.
"""
function dsin(du, u, p, t)
    du[1] = cos(t)
    # In-place RHS functions are called for their side effect; return nothing
    # instead of leaking the assigned value.
    return nothing
end
# --- Hyperparameters -------------------------------------------------------
niterinner = 1   # gradient-descent steps taken by the inner 4D-Var solver
η0 = 0.01        # initial inner step size (the hyperparameter being tuned)
niterouter = 2   # outer (hyperparameter) optimisation steps
lr = 0.001       # outer learning rate

# Inner variational cost, as a function of the initial state only.
xvarcost = x -> Tmp.varcost(t, dsin, ss_idx, x, obs)

# Baseline reconstruction with the untuned step size.
x_η0 = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η0)

# Map a step size to its reconstructed trajectory, and score a trajectory
# against the ground truth.
ηvarsolve = η -> Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)
training_loss = x -> mean((x - gt) .^ 2)

# Tune η by differentiating through the inner gradient descent.
# Note: @time on the first call also includes compilation time.
@time η = Tmp.outer_fit_rev(η0, ηvarsolve, training_loss, niterouter, lr)

# Reconstruction with the tuned step size.
x_η = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)

# --- Visual comparison -----------------------------------------------------
plot(t, gt, label="gt")
scatter!(t[ss_idx], obs, label="obs")
plot!(t, x_η0, label="x_η0")
plot!(t, x_η, label="x_η")