Nested gradient computations with differential equation solver

Hi, I am exploring the possibilities of combining Julia’s differential equation solvers with its automatic differentiation capabilities, and I am stuck trying to backpropagate twice through an ODE solver.

Here is some code to reproduce the problem:

module Minimal

using Zygote
using DifferentialEquations
using DiffEqSensitivity

t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

# in-place ODE right-hand side: du/dt = cos(t), so the solution is sin(t) + C
function dsin(du, u, p, t)
	du[1] = cos(t)
end

# inner gradient: differentiate a cost that involves the ODE solve w.r.t. x
function innergrad()
	g = gradient(x0) do x
		prob = ODEProblem(dsin, [x[1]], (1, 10))
		sol = DifferentialEquations.solve(
			prob, RK4(), saveat=collect(1:0.1:10),
		)

		solx = [xx[1] for xx in sol.u]

		sum((solx - x) .^ 2)
	end
	g[1]
end

# outer gradient: differentiate w.r.t. the step size η (this is the part that fails)
function outergrad()
	gradient(η0) do η
		g = innergrad()
		sum( (gt - (x0 - η*g)).^2 )
	end
end
end

Minimal.innergrad() # OK
Minimal.outergrad() # KO

Here are my installed packages (Pkg.status()):

  [a077e3f3] DiffEqProblemLibrary v4.15.0
  [41bf760c] DiffEqSensitivity v6.71.0
  [0c46a032] DifferentialEquations v7.1.0
  [61744808] DynamicalSystems v2.1.8
  [587475ba] Flux v0.12.9
  [7073ff75] IJulia v1.23.2
  [a98d9a8b] Interpolations v0.13.5
  [2b0e0bc5] LanguageServer v4.2.0
  [30363a11] NetCDF v0.11.4
  [1dea7af3] OrdinaryDiffEq v6.7.1
  [91a5bcdd] Plots v1.27.3
  [438e738f] PyCall v1.93.1
  [e88e6eb3] Zygote v0.6.37

and Julia version: 1.7.2
Any help greatly appreciated :slight_smile:

Why backprop twice? Second-order differentiation of an ODE solver is essentially always faster via forward-over-reverse, and that already exists. It happens automatically when using Newton:

https://diffeqflux.sciml.ai/dev/examples/second_order_adjoints/

and uses the following:

https://diffeq.sciml.ai/stable/analysis/sensitivity/#Second-Order-Sensitivity-Analysis-via-second_order_sensitivities

And note that with forward-over-reverse you can do Hessian-free operations, i.e. Hessian-vector products (Hv), to run Newton-Krylov without ever computing full Hessians.
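
To give a feel for what that means outside the ODE setting, here is a minimal sketch of forward-over-reverse (the names loss, grad and hvp are purely illustrative, not code from this thread): ForwardDiff pushes a dual number through a Zygote gradient, which yields Hessian-vector products without ever building the Hessian.

using Zygote, ForwardDiff

# any scalar loss; a stand-in for the ODE-based cost in the question
loss(x) = sum(abs2, x .^ 2 .- 1)

# reverse-mode gradient
grad(x) = Zygote.gradient(loss, x)[1]

# forward-over-reverse Hessian-vector product:
# Hv = d/dε ∇loss(x + ε v) at ε = 0, one dual-number sweep over the reverse pass
hvp(x, v) = ForwardDiff.derivative(ε -> grad(x .+ ε .* v), 0.0)

x = randn(4); v = randn(4)
hvp(x, v)   # ≈ Hessian(loss)(x) * v, without forming the Hessian

For the ODE case, the second_order_sensitivities interface linked above packages this pattern up for you.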


The following seems to work for me

module Minimal

using Zygote
using ForwardDiff
using DifferentialEquations
using DiffEqSensitivity

t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

function dsin(du, u, p, t)
	du[1] = cos(t)
end

function innergrad()
	g = gradient(x0) do x
		prob = ODEProblem(dsin, [x[1]], (1, 10))
		sol = DifferentialEquations.solve(
			prob, RK4(), saveat=collect(1:0.1:10),
		)

		solx = [xx[1] for xx in sol.u]

		sum((solx - x) .^ 2)
	end
	g[1]
end

function outergrad()
	# key change: differentiate the outer loss with ForwardDiff (forward mode) instead of Zygote
	ForwardDiff.derivative(η0) do η
		g = innergrad()
		sum( (gt - (x0 .- η*g)).^2 )
	end
end
end

Minimal.innergrad() # OK
Minimal.outergrad() # OK now

[EDIT]: code updated with working solution

@ChrisRackauckas, thank you for your reply. To give you some additional context: the idea behind my question is to optimize the hyperparameters of a variational data assimilation scheme (4DVar); maybe my code was too minimal… (see below for the full code).

  • The inner differentiation is part of a gradient descent that minimizes a variational cost
  • I then want to optimize some hyperparameters of that gradient descent, i.e. backpropagate through the gradient descent

I’m not sure whether your solution still applies; I need to read up to better understand the terms you introduced (forward-over-reverse, Hessian-free, Newton-Krylov…). I’ll try some things and come back here with a solution or some additional questions :slight_smile:

module Tmp
using Zygote
using ForwardDiff
using ReverseDiff
using DifferentialEquations
using DiffEqSensitivity
using Plots

# 4D variational cost: integrate the ODE from initial state x and penalize the
# mismatch with the observations y at the observed indices
function varcost(t, df, obs_idx, x, y)
	prob = ODEProblem(df, [x], (1, 10))
	sol = DifferentialEquations.solve(
		prob, RK4(), saveat=collect(t),
	)

	solx = [xx[1] for xx in sol.u]
	obs_cost = sum((solx[obs_idx] - y) .^ 2)

	obs_cost
end

# Gradient-descent solver: minimize the variational cost over the initial state,
# then return the trajectory obtained from the optimized state
function varsolve(x0, df, t, xvarcost, niter, η)
	x = deepcopy(x0)
	for i in 1:niter
		g = gradient(xvarcost, x)[1]
		x = x - η * g
	end
	prob = ODEProblem(df, [x], (1, 10))
	sol = DifferentialEquations.solve(prob, RK4(), saveat=collect(t))

	[xx[1] for xx in sol.u]
end

# Outer optimization of the gradient step η (reverse mode)
function outer_fit_rev(η0, ηvarsolve, loss, niter=1, lr=0.001)
	η = deepcopy(η0)
	for i in 1:niter
		gη = ReverseDiff.gradient(loss∘ηvarsolve∘first, [η], ReverseDiff.GradientConfig([η]))[1]
		η = η - lr * gη
	end
	η
end

# Outer optimization of the gradient step η (forward mode)
function outer_fit_fwd(η0, ηvarsolve, loss, niter=1, lr=0.001)
	η = deepcopy(η0)
	for i in 1:niter
		gη = ForwardDiff.derivative(loss∘ηvarsolve, η)[1]
		η = η - lr * gη
	end
	η
end
end

using Plots
using Statistics

# ground truth and noisy, subsampled observations
t = 1:0.1:10
gt = sin.(t)
sub_samp = 3
ss_idx = 1:sub_samp:length(gt)
ss_gt = gt[ss_idx]
obs = ss_gt + randn(size(ss_gt)...) * 0.2

# initial guess for the state
x0 = ss_gt[1] - 0.3

function dsin(du, u, p, t)
	du[1] = cos(t)
end
xvarcost = (x) -> Tmp.varcost(t, dsin, ss_idx, x, obs)

# inner solve with the initial gradient step η0
niterinner = 1
η0 = 0.01

x_η0 = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η0)

# outer optimization of the gradient step
niterouter = 2
lr = 0.001

ηvarsolve = (η) -> Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)
training_loss = (x) -> mean((x - gt).^2)

@time η = Tmp.outer_fit_rev(η0, ηvarsolve, training_loss, niterouter, lr)

# inner solve with the optimized gradient step
x_η = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)

plot(t, gt, label="gt")
scatter!(t[ss_idx], obs, label="obs")
plot!(t, x_η0, label="x_η0")
plot!(t, x_η, label="x_η")

Great!! Indeed this solves my problem, thanks a lot, but to be honest I don’t really understand why. Could you maybe provide some insight or documentation about the impact of this change (switching the outer derivative to ForwardDiff)?

Whenever I encounter one of these incomprehensible Zygote error messages, I try to understand the problem using alternative AD packages.

Another test shows that

function outergrad()
	ReverseDiff.gradient([η0]) do η
		g = innergrad()
		sum( (gt - (x0 .- η.*g)).^2 )
	end
end

seems to work here too, but Zygote still has trouble handling this.
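
As a side note, one quick way to apply that strategy is to evaluate the same derivative with several backends plus a crude finite difference and compare the results. This is only a sketch; outer_loss below is a placeholder, not the actual η ↦ loss mapping from the code above.

using Zygote, ForwardDiff, ReverseDiff

# placeholder scalar function standing in for the real outer loss
outer_loss(η) = (η - 0.3)^2 + sin(η)

d_fwd = ForwardDiff.derivative(outer_loss, 0.01)
d_rev = ReverseDiff.gradient(v -> outer_loss(v[1]), [0.01])[1]
d_zyg = Zygote.gradient(outer_loss, 0.01)[1]
d_fd  = (outer_loss(0.01 + 1e-6) - outer_loss(0.01 - 1e-6)) / 2e-6

If one backend errors or disagrees, that usually narrows down where the problem sits.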


How many hyperparameters do you have? 1000? If you have something like 20 hyperparameters, then forward mode will be faster for that. Reverse mode is for very specific things and really shouldn’t be thought of as a hammer.
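
For a rough idea of what that looks like (hyper_loss and the values below are placeholders, not code from this thread), forward mode over the whole outer objective is just one ForwardDiff.gradient call:

using ForwardDiff

# placeholder for loss ∘ ηvarsolve applied to a small vector of hyperparameters
hyper_loss(θ) = (θ[1] - 0.3)^2 + abs2(θ[2])

θ0 = [0.01, 0.001]                        # e.g. gradient step and learning rate
gθ = ForwardDiff.gradient(hyper_loss, θ0) # cheap when length(θ0) is small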

In the code above, only 1 parameter, but my end goal is to train models with a few hundred thousand parameters.
I’ve played around with Flux.jl, which is why I went with Zygote.jl by default.
For your information, I am just starting out in Julia and have been using PyTorch so far, so I haven’t had to think too much about the AD part of my work; all your comments are very much appreciated :wink:

Not parameters, hyperparameters.

I apologize for the lack of clarity… Let me try again:

I have a parametrized function f that performs a gradient descent on a variational cost (which includes integrating some differential equations):

f_\theta(y) = \hat{x} = \underset{x}{\arg\min}\, J(y, x)

I want to learn the parameters \theta of this function:

\Theta = \underset{\theta}{\arg\min}\, L(f_\theta)

In my example so far the only parameter of this function was the gradient step used to minimize J. I used the terminology hyperparameter (probably a little bit abusively) because I thought the problem was similar to optimizing for the learning rate in a classical ML problem.

In my target use case, this parametrized function contains a couple of neural networks with a few hundred thousand parameters. They are not really hyperparameters, but optimizing them does require nested gradient descent.
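
For illustration, take the simplest case of a single inner step from a fixed x_0, with \theta entering the variational cost J:

\hat{x}(\theta) = x_0 - \eta \, \nabla_x J_\theta(y, x_0), \qquad \frac{dL}{d\theta} = \Big(\frac{\partial \hat{x}}{\partial \theta}\Big)^{\top} \nabla_x L(\hat{x}) = -\eta \, \Big(\frac{\partial\, \nabla_x J_\theta(y, x_0)}{\partial \theta}\Big)^{\top} \nabla_x L(\hat{x})

so the outer gradient has to differentiate the inner gradient \nabla_x J_\theta, which itself already involves differentiating through the ODE solve.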

I hope this was clearer. And thanks again for taking the time