# Slow backward gradients in ODEs

Hi all,

I am new to Julia and I am trying to solve an ODE and obtain its gradients with respect to the parameters `p`. The ODE is the Hodgkin-Huxley model, a standard model in neuroscience (the full ODE is at the end of this message).

I can solve the ODE as follows:

p = [70.0,10.0,0.05];
u0 = [-59;0.2;0.9;0.1];
tspan = [0.0,600.0];
prob = ODEProblem(my_ode,u0,tspan,p)
t_solve = @elapsed sol = solve(prob,Euler(),dt=0.01);
println("Time to solve:  ", t_solve)


And I can get its gradients with:

function G(p)
    tmp_prob = remake(prob, u0 = convert.(eltype(p), u0), p = p)
    sum(solve(tmp_prob, Euler(), dt = 0.01))   # scalar loss for the gradient
end
t_grad = @elapsed grad = ReverseDiff.gradient(G, p);
println("Time for gradient:  ", t_grad)


On my machine, it takes about 0.05 seconds to solve the ODE, but it takes about 3.5 seconds to obtain the gradient.

My question is:

• Am I doing anything obviously wrong?
• Why is backpropagation 70 times more expensive than the solve? In neural networks, for example, the gradient usually costs about 1-4 times as much as the forward pass.

Note: I know that I should not be using Euler, and that ForwardDiff is faster in my current example; the model I am eventually interested in has >50 parameters, though, which is why I care about reverse mode. I am just curious whether there is some way of improving the efficiency of reverse-mode gradients. I have tried using Zygote, but the problem persists.

Thanks a lot for your help!

Code for ODE:

using ReverseDiff
using DifferentialEquations

function my_ode(du, u, p, t)
    # State: membrane potential E1 and gating variables m1, h1, n1
    E1 = u[1]
    m1 = u[2]
    h1 = u[3]
    n1 = u[4]

    V_T = -63.0
    g_Na1, g_Kd1, g_leak1 = p

    # Sodium activation gate m
    alpha_act1 = -0.32 * (E1 - V_T - 13) / (exp(-(E1 - V_T - 13) / 4) - 1)
    beta_act1 = 0.28 * (E1 - V_T - 40) / (exp((E1 - V_T - 40) / 5) - 1)
    du[2] = (alpha_act1 * (1.0 - m1)) - (beta_act1 * m1)

    # Sodium inactivation gate h
    alpha_inact1 = 0.128 * exp(-(E1 - V_T - 17) / 18)
    beta_inact1 = 4 / (1 + exp(-(E1 - V_T - 40) / 5))
    du[3] = (alpha_inact1 * (1.0 - h1)) - (beta_inact1 * h1)

    # Potassium activation gate n
    alpha_kal1 = -0.032 * (E1 - V_T - 15) / (exp(-(E1 - V_T - 15) / 5) - 1)
    beta_kal1 = 0.5 * exp(-(E1 - V_T - 10) / 40)
    du[4] = (alpha_kal1 * (1.0 - n1)) - (beta_kal1 * n1)

    # Membrane currents and the voltage equation
    area = 20000.0
    I_Na1 = -(E1 - 50) * g_Na1 * area * (m1^3) * h1
    I_K1 = -(E1 - (-90)) * g_Kd1 * area * n1^4
    I_leak1 = -(E1 - (-65)) * g_leak1 * area

    I_ext = 1.0 * area
    du[1] = (I_leak1 + I_K1 + I_Na1 + I_ext) / (1 * area)
end


Interesting. Benchmarking in global scope and with `@elapsed` is always error-prone. If I have done everything correctly,

using BenchmarkTools

function test1()
    p = [70.0, 10.0, 0.05]
    u0 = [-59; 0.2; 0.9; 0.1]
    tspan = [0.0, 600.0]
    prob = ODEProblem(my_ode, u0, tspan, p)
    sol = solve(prob, Euler(), dt = 0.01)
end

@btime test1()

function G(p)
    u0 = [-59; 0.2; 0.9; 0.1]
    tspan = [0.0, 600.0]
    prob = ODEProblem(my_ode, u0, tspan, p)
    tmp_prob = remake(prob, u0 = convert.(eltype(p), u0), p = p)
    sum(solve(tmp_prob, Euler(), dt = 0.01))   # scalar loss
end

function test2()
    p = [70.0, 10.0, 0.05]
    ReverseDiff.gradient(G, p)
end

@btime test2()


yields

16.679 ms (240307 allocations: 24.58 MiB)     # test1: forward solve
2.251 s (24660957 allocations: 977.22 MiB)    # test2: ReverseDiff gradient


for me. You could also `const`-ify your global variables. Note that `@elapsed` potentially measures compilation time on top of runtime.
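For reference, the two usual remedies look roughly like this (a sketch; `p_const` is just an illustrative name):

using BenchmarkTools

const p_const = [70.0, 10.0, 0.05]        # const globals have a concrete, fixed type
@btime solve($prob, Euler(), dt = 0.01)   # or interpolate globals into @btime with $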

Edit: corrected mistakes.


Thanks a lot! The difference indeed seems to stem from working in global scope (substituting `@btime` with `@elapsed` in your example also gives similar timings).

Unfortunately, I have to reopen this. There is a bug in the answer: in the function G(p), the input p is unused and overwritten. After fixing this bug, the runtime for the gradient is as high as originally posted.
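To make the bug concrete, the problematic pattern is argument shadowing; a minimal sketch with a stand-in loss (not the thread's actual G):

# Buggy: the argument is immediately shadowed, so the AD tool never
# tracks the parameters that were actually passed in.
G_buggy(p) = begin
    p = [70.0, 10.0, 0.05]   # overwrites the input argument
    sum(abs2, p)             # stand-in for the real loss
end

# Fixed: use the argument as given.
G_fixed(p) = sum(abs2, p)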


Two more data points. As you mentioned, ForwardDiff is good here:

using ForwardDiff

function test3()
    p = [70.0, 10.0, 0.05]
    ForwardDiff.gradient(G, p)   # G as above, with the bug fixed
end

@btime test3()


with

 24.930 ms (240307 allocations: 48.38 MiB)


Zygote fails with a compile error: `promotion of types Static.False and Int64 failed to change any arguments`.

One more update (got Zygote working):

using ForwardDiff, ReverseDiff, Zygote

function test1()
    p = [70.0, 10.0, 0.05]
    u0 = [-59; 0.2; 0.9; 0.1]
    tspan = [0.0, 600.0]
    prob = ODEProblem(my_ode, u0, tspan, p)
    sol = solve(prob, Euler(), dt = 0.01, saveat = 600.0)
end

@show test1()

G1(p) = begin
    u0 = convert.(eltype(p), [-59; 0.2; 0.9; 0.1])
    tspan = [0.0, 600.0]
    prob = ODEProblem(my_ode, u0, tspan, p)
    sum(solve(prob, Euler(), dt = 0.01, saveat = 600.0))   # scalar loss
end

function test2()   # ReverseDiff
    p = [70.0, 10.0, 0.05]
    ReverseDiff.gradient(G1, p)
end

@show test2()

function test3()   # ForwardDiff
    p = [70.0, 10.0, 0.05]
    ForwardDiff.gradient(G1, p)
end

@show test3()

function test4()   # Zygote
    p = [70.0, 10.0, 0.05]
    Zygote.gradient(G1, p)
end

@show test4()

@btime test1()
@btime test2()
@btime test3()
@btime test4()


yields

  17.118 ms (240445 allocations: 24.59 MiB)
2.133 s (24661093 allocations: 977.24 MiB) # ReverseDiff
25.753 ms (240487 allocations: 48.40 MiB) # ForwardDiff
168.403 ms (1859063 allocations: 183.53 MiB) # Zygote


so Zygote does considerably better than ReverseDiff here, though still slower than ForwardDiff. I also tried to optimize the ReverseDiff case using ReverseDiff.@forward, without much success. Haven't checked yet whether Zygote has a similar mechanism. One more observation: the possible choices of sensealg seem to depend on the AD backend.
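For reference, a minimal sketch of what ReverseDiff.@forward looks like (this assumes the rate expressions are factored out into scalar helpers; `alpha_act` is a hypothetical refactoring, not code from this thread):

using ReverseDiff

# A scalar helper marked @forward is differentiated with forward-mode
# duals during the reverse sweep, keeping its internals off the tape.
ReverseDiff.@forward alpha_act(E) = -0.32 * (E + 63.0 - 13) / (exp(-(E + 63.0 - 13) / 4) - 1)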

First of all, this model is tiny. If you go by the results of [1812.01892] A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions, you need about 100 ODEs to overcome the cost of reverse-mode AD in the best case, and more than 1000 in the worst case. Reverse mode just generally has higher overhead but better scaling, so in this case it shouldn't be surprising that ForwardDiff is best.
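To put rough numbers on that using the timings above (an illustrative back-of-envelope, not from the paper): forward mode costs roughly (1 + n_p) × the forward solve, while tape-based reverse mode costs some constant c × the solve, with c large. Here n_p = 3, and indeed ForwardDiff lands at about 25.8 ms / 17.1 ms ≈ 1.5× the solve, while ReverseDiff pays about 2.13 s / 17.1 ms ≈ 125×.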

For the other result: Zygote plugs into the DiffEqSensitivity setup (Local Sensitivity Analysis (Automatic Differentiation) · DifferentialEquations.jl), so it ends up "not being too bad" because it is using a hand-written adjoint rule, where all of the internal work is probably delegated to a vjp defined by Enzyme to handle the mutation well.
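For illustration, selecting that adjoint and its vjp backend explicitly looks roughly like this (a sketch using names from DiffEqSensitivity, not code from the thread):

using DiffEqSensitivity, Zygote

# Ask for the interpolating adjoint with an Enzyme-generated vjp.
loss(p) = sum(solve(prob, Euler(), dt = 0.01, p = p,
                    sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())))
dp, = Zygote.gradient(loss, p)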

I'm not sure if ReverseDiff.jl ends up hitting the same rules? I think we defined a connection here (DiffEqBase.jl/reversediff.jl at master · SciML/DiffEqBase.jl · GitHub), but I'd need to double-check whether that actually catches it. The problem with ReverseDiff.jl, though, is that if you accidentally use it in Array{TrackedReal} form (instead of TrackedArray), you can miss dispatches like this and get huge performance drops. I think the only real application for ReverseDiff.jl at the user level is compiled tape mode in cases where Array{TrackedReal} is used to avoid errors (i.e., to handle mutation).
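For completeness, a minimal sketch of ReverseDiff's compiled-tape mode (assuming a loss like G1 above whose control flow does not depend on p, e.g. a fixed-step solve):

using ReverseDiff

p0 = [70.0, 10.0, 0.05]

# Record the operations of G1 once, then compile the tape for reuse.
tape  = ReverseDiff.GradientTape(G1, p0)
ctape = ReverseDiff.compile(tape)

# Subsequent gradients reuse the compiled tape instead of re-recording.
grad = similar(p0)
ReverseDiff.gradient!(grad, ctape, p0)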


Fantastic, thanks a lot for the detailed reply!