# Nested gradient computations with differential equation solver

Hi, I am exploring the possibilities of combining Julia’s differential equation solvers and automatic differentiation capabilities, and I am stuck trying to backpropagate twice through an ODE solver.

Here is some code to reproduce the problem:

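```julia
module Minimal

using Zygote
using DifferentialEquations
using DiffEqSensitivity

t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

function dsin(du, u, p, t)
    du[1] = cos(t)
end

# The enclosing function and the two gradient calls are assumed here:
# an outer Zygote gradient over η wrapping an inner one over x.
function outergrad()
    Zygote.gradient(η0) do η
        gs = Zygote.gradient(x0) do x
            prob = ODEProblem(dsin, [x[1]], (1, 10))
            sol = DifferentialEquations.solve(
                prob, RK4(), saveat=collect(1:0.1:10),
            )

            solx = [xx[1] for xx in sol.u]

            sum((solx - x) .^ 2)
        end
        g = gs[1]

        sum((gt - (x0 - η * g)) .^ 2)
    end
end
end
```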



Here are my installed packages (Pkg.status()):

```
[a077e3f3] DiffEqProblemLibrary v4.15.0
[41bf760c] DiffEqSensitivity v6.71.0
[0c46a032] DifferentialEquations v7.1.0
[61744808] DynamicalSystems v2.1.8
[587475ba] Flux v0.12.9
[7073ff75] IJulia v1.23.2
[a98d9a8b] Interpolations v0.13.5
[2b0e0bc5] LanguageServer v4.2.0
[30363a11] NetCDF v0.11.4
[1dea7af3] OrdinaryDiffEq v6.7.1
[91a5bcdd] Plots v1.27.3
[438e738f] PyCall v1.93.1
[e88e6eb3] Zygote v0.6.37
```

and my Julia version is 1.7.2.
Any help is greatly appreciated.

Why backprop twice? Second-order derivatives through an ODE solver are essentially always faster to compute via forward-over-reverse, and that already exists. It’ll happen automatically when using Newton:

and uses the following:

https://diffeq.sciml.ai/stable/analysis/sensitivity/#Second-Order-Sensitivity-Analysis-via-second_order_sensitivities

And note, when you do forward-over-reverse, you can do Hessian-free operations, i.e. Hv, to do Newton-Krylov without computing Hessians.
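As a rough illustration of forward-over-reverse, here is a minimal sketch of a Hessian-vector product on a toy function (no ODE involved). The helper name `hvp` and the objective `f` are made up for this example; only `Zygote.gradient` and `ForwardDiff.derivative` are real library calls.

```julia
using Zygote
using ForwardDiff

# Toy objective (hypothetical); its Hessian is diagonal with entries 6x.
f(x) = sum(x .^ 3)

# Hessian-vector product H(x) * v without ever forming H:
# take the reverse-mode gradient, then differentiate it in forward mode
# along the direction v.
function hvp(f, x, v)
    ForwardDiff.derivative(ε -> Zygote.gradient(f, x .+ ε .* v)[1], 0.0)
end

x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, -1.0]
Hv = hvp(f, x, v)
```

Each call costs roughly one gradient evaluation with dual numbers, which is what makes Krylov-type methods practical without a full Hessian.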


The following seems to work for me

```julia
module Minimal

using Zygote
using ForwardDiff
using DifferentialEquations
using DiffEqSensitivity

t = 1:0.1:10
gt = sin.(t)
x0 = gt + randn(size(gt)...) * 0.2
η0 = 0.01

function dsin(du, u, p, t)
    du[1] = cos(t)
end

# The enclosing function and the inner Zygote call are assumed here;
# the ForwardDiff call over η is as posted.
function outergrad()
    gs = Zygote.gradient(x0) do x
        prob = ODEProblem(dsin, [x[1]], (1, 10))
        sol = DifferentialEquations.solve(
            prob, RK4(), saveat=collect(1:0.1:10),
        )

        solx = [xx[1] for xx in sol.u]

        sum((solx - x) .^ 2)
    end
    g = gs[1]

    ForwardDiff.derivative(η0) do η
        sum((gt - (x0 .- η * g)) .^ 2)
    end
end
end
```



[EDIT]: code updated with working solution

@ChrisRackauckas, thank you for your reply. To give you some additional context: the idea behind my question is to optimize the hyperparameters of a variational data assimilation scheme (4D-Var). Maybe my code was too minimal (see below for the full code).

• The inner differentiation is part of a gradient descent to minimize a variational cost
• I want to then optimize for some hyperparameters of the gradient descent, i.e. backpropagating through the gradient descent
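To make the two levels concrete, here is a minimal sketch on a toy scalar problem with no ODE. Everything in it (`J`, `L`, `inner_descent`, the numbers) is hypothetical and only stands in for the variational cost and training loss; the inner gradient uses reverse mode (Zygote) and the derivative with respect to the step size uses forward mode (ForwardDiff) over the unrolled loop.

```julia
using Zygote
using ForwardDiff

J(x, y) = (x - y)^2     # toy inner cost, stands in for the variational cost
L(x, y) = (x - y)^2     # toy outer (training) loss

# Unrolled inner gradient descent with step size η.
function inner_descent(x0, y, η, niter)
    x = x0
    for _ in 1:niter
        g = Zygote.gradient(z -> J(z, y), x)[1]   # reverse mode on the inner cost
        x = x - η * g
    end
    x
end

# Forward mode through the whole unrolled loop gives dL/dη at η = 0.1.
y = 1.0
dLdη = ForwardDiff.derivative(η -> L(inner_descent(0.0, y, η, 2), y), 0.1)
```

From the second inner iteration on, `x` carries dual numbers, so the reverse-mode gradient is itself being differentiated in forward mode.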

I’m not sure if your solution still applies. I need to read up to better understand the different terms you introduced (forward-over-reverse, Hessian-free, Newton-Krylov…). I’ll try a few things and come back here with a solution or some additional questions.

```julia
module Tmp
using Zygote
using ForwardDiff
using ReverseDiff
using DifferentialEquations
using DiffEqSensitivity
using Plots

# 4D variational cost: integrate from initial state x, compare with observations y
function varcost(t, df, obs_idx, x, y)
    prob = ODEProblem(df, [x], (1, 10))
    sol = DifferentialEquations.solve(
        prob, RK4(), saveat=collect(t),
    )

    solx = [xx[1] for xx in sol.u]
    obs_cost = sum((solx[obs_idx] - y) .^ 2)

    obs_cost
end

# Inner problem: niter gradient steps of size η on the variational cost,
# then return the trajectory from the optimized initial state
function varsolve(x0, df, t, xvarcost, niter, η)
    x = deepcopy(x0)
    for i in 1:niter
        # assumed: inner gradient taken with Zygote
        g = Zygote.gradient(xvarcost, x)[1]
        x = x - η * g
    end
    prob = ODEProblem(df, [x], (1, 10))
    sol = DifferentialEquations.solve(prob, RK4(), saveat=collect(t))

    [xx[1] for xx in sol.u]
end

# Outer problem, reverse mode: gradient descent on the step size η
function outer_fit_rev(η0, ηvarsolve, loss, niter=1, lr=0.001)
    η = deepcopy(η0)
    for i in 1:niter
        # assumed: Zygote counterpart of the ForwardDiff call in outer_fit_fwd
        gη = Zygote.gradient(loss∘ηvarsolve, η)[1]
        η = η - lr * gη
    end
    η
end

# Outer problem, forward mode
function outer_fit_fwd(η0, ηvarsolve, loss, niter=1, lr=0.001)
    η = deepcopy(η0)
    for i in 1:niter
        gη = ForwardDiff.derivative(loss∘ηvarsolve, η)[1]
        η = η - lr * gη
    end
    η
end
end

using Plots
using Statistics

t = 1:0.1:10
gt = sin.(t)
sub_samp = 3
ss_idx = 1:sub_samp:length(gt)
ss_gt = gt[ss_idx]
obs = ss_gt + randn(size(ss_gt)...) * 0.2

x0 = ss_gt[1] - 0.3

function dsin(du, u, p, t)
    du[1] = cos(t)
end
xvarcost = (x) -> Tmp.varcost(t, dsin, ss_idx, x, obs)

niterinner = 1
η0 = 0.01

x_η0 = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η0)

niterouter = 2
lr = 0.001

ηvarsolve = (η) -> Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)
training_loss = (x) -> mean((x - gt) .^ 2)

@time η = Tmp.outer_fit_rev(η0, ηvarsolve, training_loss, niterouter, lr)

x_η = Tmp.varsolve(x0, dsin, t, xvarcost, niterinner, η)

plot(t, gt, label="gt")
scatter!(t[ss_idx], obs, label="obs")
plot!(t, x_η0, label="x_η0")
plot!(t, x_η, label="x_η")
```



Great! This indeed solves my problem, thanks a lot. To be honest, though, I don’t really understand why. Could you provide some insight, or point me to documentation on the impact of this import?

Whenever I encounter one of these incomprehensible error messages with Zygote I try to understand the problem using alternate AD packages.

Another test shows that

```julia
function outergrad()
    # ... (the call that differentiates with respect to η is not shown) ...
        sum( (gt - (x0 .- η.*g)).^2 )
    end
end
```


seems to work here too, but Zygote still has trouble handling this.


How many hyperparameters do you have? 1000? If you have like 20 hyperparameters, then using forward-mode will be faster on that. Reverse mode is for very specific things and really shouldn’t be thought of as a hammer.
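For a handful of hyperparameters, forward mode over the whole unrolled inner loop can look like the following sketch. The setup is hypothetical: a toy inner cost `(x - y)^2 + λ x^2` whose gradient is written by hand, with `η` (step size) and `λ` (regularization weight) as the two hyperparameters. Only `ForwardDiff.gradient` is a real library call.

```julia
using ForwardDiff

# Outer loss as a function of the hyperparameter vector h = [η, λ].
# The inner loop runs niter gradient steps on the toy cost
# J(x) = (x - y)^2 + λ x^2, whose gradient is 2(x - y) + 2λx.
function outer_loss(h; x0=0.0, y=1.0, xtrue=0.8, niter=5)
    η, λ = h
    x = x0
    for _ in 1:niter
        x = x - η * (2 * (x - y) + 2 * λ * x)
    end
    (x - xtrue)^2
end

h0 = [0.1, 0.05]                          # [η, λ]
gh = ForwardDiff.gradient(outer_loss, h0) # one entry per hyperparameter
```

The cost of forward mode grows with the number of hyperparameters, which is why it suits a couple of dozen of them and not hundreds of thousands.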

In the code above there is only one parameter, but my end goal is to train models with a few hundred thousand parameters.
I’ve played around with Flux.jl, which is why I went with Zygote.jl by default.
For your information, I am just starting out in Julia and have been using PyTorch so far, so I haven’t had to think much about the AD part of my work. All your comments are very much appreciated.

Not parameters, hyperparameters.

I apologize for the lack of clarity… Let me try again:

I have a parametrized function f that performs a gradient descent on a variational cost (which involves integrating differential equations):

$$f_{\theta}(y) = \hat{x} = \underset{x}{\operatorname{argmin}}\; J(y, x)$$

I want to learn the parameters $\Theta$ of this function.

$$\Theta = \underset{\theta}{\operatorname{argmin}}\; L(f_\theta)$$

In my example so far the only parameter of this function was the gradient step used to minimize J. I used the terminology hyperparameter (probably a little bit abusively) because I thought the problem was similar to optimizing for the learning rate in a classical ML problem.
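For a single inner step from a fixed starting point $x_0$, the dependence on the step size can be written out explicitly. This is only a sketch under that assumption, with $\eta$ as the sole parameter and $L$ evaluated on the descent output $\hat{x}$:

$$
\hat{x}(\eta) = x_0 - \eta \nabla_x J(y, x_0), \qquad \frac{dL}{d\eta} = -\nabla_x L(\hat{x})^{\top} \nabla_x J(y, x_0)
$$

With two steps, $x_2 = x_1 - \eta \nabla_x J(y, x_1)$ and $x_1$ itself depends on $\eta$, so

$$
\frac{dx_2}{d\eta} = \left(I - \eta \nabla_x^2 J(y, x_1)\right) \frac{dx_1}{d\eta} - \nabla_x J(y, x_1)
$$

which is where the second-order term, and hence the nested differentiation, comes from.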

In my target use case this parametrized function contains a couple of neural networks with a few hundred thousand parameters. They are not really hyperparameters, but optimizing them does require nested gradient descent.

I hope this was clearer, and thanks again for taking the time.