Nested gradient computations with differential equation solver

Hi, I am exploring the possibilities of combining Julia’s differential equation solvers and automatic differentiation capabilities, and I am stuck trying to backpropagate twice through an ODE solver.

Here is some code to reproduce the problem:

module Minimal

using Zygote
using DifferentialEquations
using DiffEqSensitivity

t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

function dsin(du, u, p, t)
	du[1] = cos(t)
end

function innergrad()
	g = gradient(x0) do x
		prob = ODEProblem(dsin, [x[1]], (1, 10))
		sol = DifferentialEquations.solve(
			prob, RK4(), saveat=collect(1:0.1:10),
		)

		solx = [xx[1] for xx in sol.u]

		sum((solx - x) .^ 2)
	end
	g[1]
end

function outergrad()
	gradient(η0) do η
		g = innergrad()
		sum( (gt - (x0 - η*g)).^2 )
	end
end
end

Minimal.innergrad() # OK
Minimal.outergrad() # KO

Here are my installed packages (Pkg.status()):

  [a077e3f3] DiffEqProblemLibrary v4.15.0
  [41bf760c] DiffEqSensitivity v6.71.0
  [0c46a032] DifferentialEquations v7.1.0
  [61744808] DynamicalSystems v2.1.8
  [587475ba] Flux v0.12.9
  [7073ff75] IJulia v1.23.2
  [a98d9a8b] Interpolations v0.13.5
  [2b0e0bc5] LanguageServer v4.2.0
  [30363a11] NetCDF v0.11.4
  [1dea7af3] OrdinaryDiffEq v6.7.1
  [91a5bcdd] Plots v1.27.3
  [438e738f] PyCall v1.93.1
  [e88e6eb3] Zygote v0.6.37

and Julia version: 1.7.2
Any help greatly appreciated :slight_smile:

Why backprop twice? Second-order differentiation of an ODE solver is essentially always faster via forward-over-reverse, and that already exists. It’ll happen automatically when using Newton:

https://diffeqflux.sciml.ai/dev/examples/second_order_adjoints/

and it uses the following:

https://diffeq.sciml.ai/stable/analysis/sensitivity/#Second-Order-Sensitivity-Analysis-via-second_order_sensitivities

And note that with forward-over-reverse you can do Hessian-free operations, i.e. Hessian-vector products (Hv), so Newton-Krylov runs without ever computing full Hessians.
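To make that concrete, here is a generic sketch of the forward-over-reverse pattern (not the second_order_sensitivities API itself, just the underlying idea): a Hessian-vector product is obtained by forward-differentiating a reverse-mode gradient, so nothing ever builds the full Hessian.

using Zygote, ForwardDiff

# Any twice-differentiable scalar loss
loss(x) = sum(abs2, x) + sum(x)^3

# Reverse-mode gradient, itself differentiable in forward mode
∇loss(x) = Zygote.gradient(loss, x)[1]

# Hv = d/dε ∇loss(x + ε v) at ε = 0: one forward pass over the gradient,
# which is all a Newton-Krylov iteration needs
hvp(x, v) = ForwardDiff.derivative(ε -> ∇loss(x .+ ε .* v), 0.0)

x, v = randn(5), randn(5)
hvp(x, v)   # ≈ Hessian(loss)(x) * v, without ever forming the Hessian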

The following seems to work for me

module Minimal

using Zygote
using ForwardDiff
using DifferentialEquations
using DiffEqSensitivity

t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

function dsin(du, u, p, t)
	du[1] = cos(t)
end

function innergrad()
	g = gradient(x0) do x
		prob = ODEProblem(dsin, [x[1]], (1, 10))
		sol = DifferentialEquations.solve(
			prob, RK4(), saveat=collect(1:0.1:10),
		)

		solx = [xx[1] for xx in sol.u]

		sum((solx - x) .^ 2)
	end
	g[1]
end

function outergrad()
	ForwardDiff.derivative(η0) do η
		g = innergrad()
		sum( (gt - (x0 .- η*g)).^2 )
	end
end
end

Minimal.innergrad() # OK
Minimal.outergrad() # OK (with ForwardDiff for the outer derivative)

[EDIT]: code updated with working solution

@ChrisRackauckas, thank you for your reply. Just to give you some additional context: the idea behind my question is to optimize the hyperparameters of a variational data assimilation scheme (4D-Var); maybe my code was too minimal… (see below for the full code).

  • The inner differentiation is part of a gradient descent to minimize a variational cost
  • I then want to optimize some hyperparameters of that gradient descent, i.e. backpropagate through the gradient descent itself

I’m not sure if your solution still applies; I need to read up to better understand the different terms you introduced (forward-over-reverse, Hessian-free, Newton-Krylov…). I’ll try some things and come back here with a solution or some additional questions :slight_smile:

module Tmp
using Zygote
using ForwardDiff
using ReverseDiff
using DifferentialEquations
using DiffEqSensitivity
using Plots


# 4D variational cost
function varcost(t, df, obs_idx, x, y)
	prob = ODEProblem(df, [x], (1, 10))
	sol = DifferentialEquations.solve(
		prob, RK4(), saveat=collect(t),
	)

	solx = [xx[1] for xx in sol.u]
	obs_cost = sum((solx[obs_idx] - y) .^ 2)

	obs_cost
end

# Gradient descent solver
function varsolve(x0, df, t, xvarcost, niter, η)
	x = deepcopy(x0)
	for i in 1:niter
		g = gradient(xvarcost, x)[1]
		x = x - η * g
	end
	prob = ODEProblem(df, [x], (1, 10))
	sol = DifferentialEquations.solve(prob, RK4(), saveat=collect(t))

	[xx[1] for xx in sol.u]
end


function outer_fit_rev(η0, ηvarsolve, loss, niter=1, lr=0.001)
	η = deepcopy(η0)
	for i in 1:niter
		gη = ReverseDiff.gradient(loss∘ηvarsolve∘first, [η], ReverseDiff.GradientConfig([η]))[1]
		η = η - lr * gη
	end
	η
end
function outer_fit_fwd(η0, ηvarsolve, loss, niter=1, lr=0.001)
	η = deepcopy(η0)
	for i in 1:niter
		gη = ForwardDiff.derivative(loss∘ηvarsolve, η)[1]
		η = η - lr * gη
	end
	η
end
end

using Plots
using Statistics



t = 1:0.1:10
gt = sin.(t)
sub_samp = 3
ss_idx = 1:sub_samp:length(gt)
ss_gt = gt[ss_idx]
obs = ss_gt + randn(size(ss_gt)...) * 0.2

x0 = ss_gt[1] - 0.3

function dsin(du, u, p, t)
	du[1] = cos(t)
end
xvarcost = (x) -> Tmp.varcost(t, dsin, ss_idx, x, obs)

niterinner = 1
η0 = 0.01

x_η0 = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η0)


niterouter = 2
lr = 0.001

ηvarsolve =  (η)-> Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)
training_loss = (x) -> mean((x - gt).^2)

@time η = Tmp.outer_fit_rev(η0, ηvarsolve, training_loss, niterouter, lr)


x_η = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)

plot(t, gt, label="gt")
scatter!(t[ss_idx], obs, label="obs")
plot!(t, x_η0, label="x_η0")
plot!(t, x_η, label="x_η")

Great!! Indeed this solves my problem, thanks a lot, but to be honest I don’t really understand why. Could you maybe provide some insight, or point me to some documentation, on the impact of this import?

Whenever I encounter one of these incomprehensible error messages with Zygote, I try to understand the problem using alternative AD packages.

Another test shows that

function outergrad()
	ReverseDiff.gradient([η0]) do η
		g = innergrad()
		sum( (gt - (x0 .- η.*g)).^2 )
	end
end

seems to work here too, but Zygote still struggles to handle this.

How many hyperparameters do you have? 1000? If you have something like 20 hyperparameters, then forward mode will be faster there. Reverse mode is for very specific situations and really shouldn’t be thought of as a hammer.
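As a toy illustration of that rule of thumb (made-up loss, assumed count of 20 hyperparameters): forward mode costs roughly one (chunked) sweep per input, while a reverse sweep carries taping/pullback overhead that only pays off once the input dimension gets large.

using ForwardDiff, Zygote

# Stand-in scalar loss over a small hyperparameter vector (hypothetical)
outer_loss(η) = sum(abs2, η) - prod(η)

η = fill(0.01, 20)   # ~20 hyperparameters

# Forward mode: cost grows with length(η), but each sweep is cheap
g_fwd = ForwardDiff.gradient(outer_loss, η)

# Reverse mode: one pullback regardless of length(η), but with bookkeeping
# overhead that only pays off for large parameter counts
g_rev = Zygote.gradient(outer_loss, η)[1]

g_fwd ≈ g_rev   # both give the same gradient; the difference is cost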

In the code above there is only one parameter, but my end goal is to train models with a few hundred thousand parameters.
I’ve played around with Flux.jl, which is why I went with Zygote.jl by default.
For your information, I am just starting out in Julia and have been using PyTorch so far, so I haven’t had to think much about the AD part of my work; all your comments are very much appreciated :wink:

Not parameters, hyperparameters.

I apologize for the lack of clarity… Let me try again:

I have a parametrized function f that performs a gradient descent on a variational cost (which includes integrating some differential equations):

f_{\theta}(y) = \hat{x} = \underset{x}{\operatorname{argmin}}\; J(y, x)

I want to learn the parameters \Theta of this function.

\Theta = \underset{\theta}{\operatorname{argmin}}\; L(f_\theta)

In my example so far, the only parameter of this function was the gradient step size used to minimize J. I used the term hyperparameter (probably a bit abusively) because I thought the problem was similar to optimizing the learning rate in a classical ML problem.

In my target use case this parametrized function contains a couple of neural networks with a few hundred thousand parameters. They are not really hyperparameters, but optimizing them does require nested gradient descent.
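Schematically, the nested structure looks like the toy sketch below (made-up inner cost J and outer loss, with the step size η standing in for the real parameters; in my real case J includes the ODE integration):

using Zygote, ForwardDiff

# Toy inner cost J(y, x)
J(y, x) = sum(abs2, x .- y)

# Inner solver: a few unrolled gradient-descent steps on x ↦ J(y, x),
# with the step size η as the parameter to be learned
function inner_solve(y, η; niter=5)
	x = zero(y)
	for _ in 1:niter
		g = Zygote.gradient(x -> J(y, x), x)[1]
		x = x .- η .* g
	end
	x
end

# Outer loss, differentiated with respect to η in forward mode,
# through the unrolled reverse-mode inner steps
outer_loss(x̂, target) = sum(abs2, x̂ .- target)

y, target = randn(10), randn(10)
dη = ForwardDiff.derivative(η -> outer_loss(inner_solve(y, η), target), 0.01)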

I hope this was clearer. And thanks again for taking the time!