Slow backward gradients in ODEs

Hi all,

I am new to Julia, and I am trying to solve an ODE and obtain its gradients with respect to the parameters p. The ODE is the Hodgkin-Huxley model, a standard model in neuroscience (the full ODE is at the end of this message).

I can solve the ODE as follows:

p = [70.0,10.0,0.05];
u0 = [-59;0.2;0.9;0.1];
tspan = [0.0,600.0];
prob = ODEProblem(my_ode,u0,tspan,p)
t_solve = @elapsed sol = solve(prob,Euler(),dt=0.01);
println("Time to solve:  ", t_solve)

And I can get its gradients with:

function G(p)
  tmp_prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
  sol = sum(solve(tmp_prob,Euler(),dt=0.01,sensealg=SensitivityADPassThrough()))
end
t_grad = @elapsed grad = ReverseDiff.gradient(G,p)
println("Time to obtain grad:  ", t_grad)

On my machine, it takes about 0.05 seconds to solve the ODE, but it takes about 3.5 seconds to obtain the gradient.

My question is:

  • Am I doing anything obviously wrong?
  • Why is backpropagation 70 times more expensive than the solve? E.g. in neural networks, the gradient usually costs about 1-4 times as much as the forward pass.

Note: I know that I should not be using Euler, and that ForwardDiff is faster for this example. However, the model I am eventually interested in has >50 parameters, so I am curious whether there is some way of improving the efficiency of reverse-mode gradients. I have tried using Zygote, but the problem persists.
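Since the eventual model has >50 parameters, it may help to see how ForwardDiff's cost grows with the parameter count. A minimal sketch, assuming a hypothetical stand-in loss `f` (not the ODE from this thread): ForwardDiff evaluates the function once per chunk of partial derivatives, so with 60 parameters and a chunk size of 10 it calls `f` six times.

```julia
using ForwardDiff

# Hypothetical stand-in loss; the counter records how often ForwardDiff calls it.
const ncalls = Ref(0)
function f(p)
    ncalls[] += 1
    return sum(abs2, p)
end

p = rand(60)                                    # 60 parameters, like a larger model
cfg = ForwardDiff.GradientConfig(f, p, ForwardDiff.Chunk{10}())
g = ForwardDiff.gradient(f, p, cfg)             # 60 / 10 = 6 evaluations of f
println("f was called ", ncalls[], " times")
```

Each evaluation propagates 10 partials at once, so the total work grows roughly linearly with the number of parameters, whereas a reverse pass costs a roughly constant multiple of one evaluation plus the bookkeeping overhead of the reverse-mode machinery.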

Thanks a lot for your help!

Code for ODE:

using ReverseDiff
using DifferentialEquations

function my_ode(du, u, p, t)
    E1 = u[1]
    m1 = u[2]
    h1 = u[3]
    n1 = u[4]
    
    V_T = -63.0
    g_Na1, g_Kd1, g_leak1 = p
    
    alpha_act1 = -0.32 * (E1 - V_T - 13) / (exp(-(E1 - V_T - 13) / 4) - 1)
    beta_act1 = 0.28 * (E1 - V_T - 40) / (exp((E1 - V_T - 40) / 5) - 1)
    du[2] = (alpha_act1 * (1.0 - m1)) - (beta_act1 * m1)

    alpha_inact1 = 0.128 * exp(-(E1 - V_T - 17) / 18)
    beta_inact1 = 4 / (1 + exp(-(E1 - V_T - 40) / 5))
    du[3] = (alpha_inact1 * (1.0 - h1)) - (beta_inact1 * h1)

    alpha_kal1 = -0.032 * (E1 - V_T - 15) / (exp(-(E1 - V_T - 15) / 5) - 1)
    beta_kal1 = 0.5 * exp(-(E1 - V_T - 10) / 40)
    du[4] = (alpha_kal1 * (1.0 - n1)) - (beta_kal1 * n1)
    
    area = 20000.0
    I_Na1 = -(E1 - 50) * g_Na1 * area * (m1^3) * h1
    I_K1 = -(E1 - (-90)) * g_Kd1 * area * n1^4
    I_leak1 = -(E1 - (-65)) * g_leak1 * area
    
    I_ext = 1.0 * area
    du[1] = (I_leak1 + I_K1 + I_Na1 + I_ext) / (1 * area)
end

Interesting. Benchmarking in global scope and with @elapsed is always error-prone. If I have done everything correctly, running

using BenchmarkTools

function test1()
    p = [70.0,10.0,0.05];
    u0 = [-59;0.2;0.9;0.1];
    tspan = [0.0,600.0];
    prob = ODEProblem(my_ode,u0,tspan,p)
    sol = solve(prob,Euler(),dt=0.01);
end

@btime test1()

function G(p)
    u0 = [-59;0.2;0.9;0.1];
    tspan = [0.0,600.0];
    prob = ODEProblem(my_ode,u0,tspan,p)
    tmp_prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
    sol = sum(solve(tmp_prob,Euler(),dt=0.01,sensealg=SensitivityADPassThrough()))
end

function test2()
    p = [70.0,10.0,0.05];
    grad = ReverseDiff.gradient(G,p)
end

@btime test2()

yields

  16.679 ms (240307 allocations: 24.58 MiB)
  2.251 s (24660957 allocations: 977.22 MiB)

for me. You could also const-ify your global variables. Also note that @elapsed on a first call measures compilation time in addition to runtime.
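To see why a single @elapsed can mislead, here is a minimal sketch using only Base, with a hypothetical stand-in function `work`: the first call of a function with new argument types includes JIT compilation, so timing it once mixes compile time with runtime. BenchmarkTools' `@btime` (with `$` interpolation of globals) avoids this by running the code many times after it has been compiled.

```julia
# Hypothetical stand-in for an expensive call such as solve(...) or gradient(...).
function work(x)
    s = 0.0
    for v in x
        s += exp(-v) * sin(v)
    end
    return s
end

const xs = rand(10_000)        # const global: its type is known to the compiler

t_first  = @elapsed work(xs)   # first call: compiles `work` for Vector{Float64}, then runs it
t_second = @elapsed work(xs)   # second call: runtime only
println("first call: ", t_first, " s, second call: ", t_second, " s")
```

On a typical machine the first number is dominated by compilation, which is why timing the gradient only once at the top level can overstate its cost.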

Edit: corrected mistakes.


Thanks a lot! The difference indeed seems to stem from working in global scope (substituting @btime with @elapsed in your example also works fine). Thank you!

Unfortunately, I have to reopen this. There is a bug in the answer: in the function G(p), the input p is unused and overwritten. After fixing this bug, the runtime for the gradient is as high as originally posted.


Two more data points: as you mentioned, ForwardDiff is good here:

using ForwardDiff

function test3()
    p = [70.0,10.0,0.05];
    grad = ForwardDiff.gradient(G,p)
end

@btime test3()

with

 24.930 ms (240307 allocations: 48.38 MiB)

Zygote fails with a compile error: "promotion of types Static.False and Int64 failed to change any arguments".

One more update (got Zygote working):

using BenchmarkTools, ForwardDiff, ReverseDiff, Zygote

function test1()
    p = [70.0,10.0,0.05];
    u0 = [-59;0.2;0.9;0.1];
    tspan = [0.0,600.0];
    prob = ODEProblem(my_ode,u0,tspan,p)
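    # note: the documented solve keyword is `saveat`; `save_at` is not one of the common solver options
    sol = solve(prob,Euler(),dt=0.01,save_at=600.0);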
end

@show test1()

G1(p) = begin
    u0 = convert.(eltype(p), [-59;0.2;0.9;0.1]);
    tspan = [0.0,600.0];
    prob = ODEProblem(my_ode,u0,tspan,p)
    sol = sum(solve(prob,Euler(),dt=0.01,save_at=600.0))
end

G2(p) = begin
    u0 = convert.(eltype(p), [-59;0.2;0.9;0.1]);
    tspan = [0.0,600.0];
    prob = ODEProblem(my_ode,u0,tspan,p)
    sol = sum(solve(prob,Euler(),dt=0.01,save_at=600.0,sensealg=SensitivityADPassThrough()))
end

function test2()
    p = [70.0,10.0,0.05];
    grad = ReverseDiff.gradient(G2,p)
end

@show test2()

function test3()
    p = [70.0,10.0,0.05];
    grad = ForwardDiff.gradient(G1,p)
end

@show test3()

function test4()
    p = [70.0,10.0,0.05];
    grad = Zygote.gradient(G1,p)
end

@show test4()

@btime test1()
@btime test2()
@btime test3()
@btime test4()

yields

  17.118 ms (240445 allocations: 24.59 MiB) # solve only
  2.133 s (24661093 allocations: 977.24 MiB) # ReverseDiff
  25.753 ms (240487 allocations: 48.40 MiB) # ForwardDiff
  168.403 ms (1859063 allocations: 183.53 MiB) # Zygote

so Zygote does considerably better than ReverseDiff here (roughly 13x faster), though it is still well behind ForwardDiff. I also tried to optimize the ReverseDiff case using ReverseDiff.@forward, without much success. I haven’t checked yet whether Zygote has a similar mechanism. One more observation: the possible choices of sensealg seem to depend on the AD backend.

First of all, this model is tiny. If you go by the results of [1812.01892] "A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions", you need about 100 ODEs before reverse-mode AD overcomes its extra cost in the best case, and more than 1000 in the worst case. Reverse mode generally has a higher overhead but better scaling. So in this case, it shouldn't be surprising that ForwardDiff is best.
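To make the overhead/scaling trade-off concrete, here is a minimal sketch using only Base, on a hypothetical scalar test problem du/dt = -p*u rather than the Hodgkin-Huxley model. Both functions differentiate the explicit-Euler loss "sum of all states" (mirroring `sum(sol)`) with respect to p. Forward mode carries the sensitivity du/dp alongside the state in constant memory but needs one such pass per parameter; reverse mode stores the whole trajectory and sweeps backwards once.

```julia
# Hypothetical scalar test problem: du/dt = -p*u, explicit Euler with step dt:
#     us[k+1] = us[k] * (1 - dt*p),  k = 1, ..., n   (us[1] = u0)
# Loss: L(p) = sum(us), i.e. the sum of every saved state, like sum(sol).

# Forward mode: carry s = du/dp next to u. O(1) memory, one pass per parameter.
function grad_forward(p, u0, dt, n)
    u, s = u0, zero(u0)
    dL = s                                  # dL/dp, accumulated state by state
    for _ in 1:n
        # the right-hand side uses the old u, so s becomes d(u*(1-dt*p))/dp
        u, s = u * (1 - dt * p), s * (1 - dt * p) - dt * u
        dL += s
    end
    return dL
end

# Reverse mode: store the trajectory, then one backward sweep. O(n) memory.
function grad_reverse(p, u0, dt, n)
    us = Vector{typeof(u0)}(undef, n + 1)
    us[1] = u0
    for k in 1:n
        us[k + 1] = us[k] * (1 - dt * p)
    end
    λ = one(u0)                             # dL/dus[n+1]
    dL = zero(u0)
    for k in n:-1:1
        dL += λ * (-dt * us[k])             # explicit dependence of us[k+1] on p
        λ = 1 + λ * (1 - dt * p)            # dL/dus[k]: own term + propagated term
    end
    return dL
end

p, u0, dt, n = 0.5, 1.0, 0.01, 60_000       # 600 time units with dt = 0.01, as in the thread
gf = grad_forward(p, u0, dt, n)
gr = grad_reverse(p, u0, dt, n)
println("forward: ", gf, "  reverse: ", gr)
```

The two results agree; the difference is in cost. Forward mode would need one such pass per parameter (50+ passes for the larger model), while reverse mode needs a single backward sweep but must keep all 60,001 states in memory. A general-purpose tape additionally records every scalar operation inside the right-hand side at every step, which is a plausible source of the large allocation counts reported above.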

For the other result: Zygote plugs into the DiffEqSensitivity setup (see the "Local Sensitivity Analysis (Automatic Differentiation)" page of the DifferentialEquations.jl documentation), so it ends up "not being too bad" because it's using a hand-written adjoint rule, where all of the internal work is probably delegated to a VJP (vector-Jacobian product) defined by Enzyme to handle the mutations well.
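The "hand-written adjoint rule" idea can be sketched with ChainRulesCore, whose `rrule` methods Zygote picks up automatically. This is a minimal sketch on a hypothetical scalar function `euler_final` (explicit Euler for du/dt = -p*u, returning only the final state), not the actual DiffEqSensitivity rule: the pullback runs the discrete adjoint sweep itself, so Zygote never has to trace the time-stepping loop.

```julia
using ChainRulesCore, Zygote

# Hypothetical primal: explicit Euler for du/dt = -p*u, returning only the final state.
function euler_final(p::Real, u0::Real, dt::Real, n::Int)
    u = u0
    for _ in 1:n
        u = u * (1 - dt * p)
    end
    return u
end

# Hand-written reverse rule: store the trajectory, then run the adjoint recursion backwards.
function ChainRulesCore.rrule(::typeof(euler_final), p::Real, u0::Real, dt::Real, n::Int)
    T = float(promote_type(typeof(p), typeof(u0), typeof(dt)))
    us = Vector{T}(undef, n + 1)
    us[1] = u0
    for k in 1:n
        us[k + 1] = us[k] * (1 - dt * p)
    end
    function euler_final_pullback(ȳ)
        λ = unthunk(ȳ)                 # adjoint of us[n + 1]
        p̄ = zero(T)
        dt̄ = zero(T)
        for k in n:-1:1
            p̄ += λ * (-dt * us[k])     # ∂us[k+1]/∂p
            dt̄ += λ * (-p * us[k])     # ∂us[k+1]/∂dt
            λ *= (1 - dt * p)          # ∂us[k+1]/∂us[k]
        end
        # λ now holds the adjoint of u0; n is an integer and gets no tangent
        return NoTangent(), p̄, λ, dt̄, NoTangent()
    end
    return us[end], euler_final_pullback
end

# Zygote uses the rrule above instead of differentiating the loop operation by operation.
g = Zygote.gradient(p -> euler_final(p, 1.0, 0.01, 1000), 0.5)[1]
println("dL/dp = ", g)
```

The design choice mirrors the one described above: the AD system only sees one call with a known pullback, and all the per-step work happens in plain, non-traced Julia code.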

I’m not sure whether ReverseDiff.jl ends up hitting the same rules. I think we defined a connection here https://github.com/SciML/DiffEqBase.jl/blob/master/src/reversediff.jl#L47-L50 but I’d need to double-check whether that actually catches it. The problem with ReverseDiff.jl, though, is that if you accidentally end up with an Array{TrackedReal} (instead of a TrackedArray), you can miss dispatches like this and get huge performance drops. I think the only real application for ReverseDiff.jl at the user level is compiled tape mode, in cases where Array{TrackedReal} is used to avoid errors (i.e. to handle mutation).
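For reference, ReverseDiff's compiled-tape mode looks roughly like this. A minimal sketch on a hypothetical toy loss `euler_sum` (explicit Euler on du/dt = -p[1]*u + p[2], summing all states), not the Hodgkin-Huxley solve: the tape is recorded once, compiled, and then replayed for new parameter values. Replaying is only valid when the control flow does not depend on the parameter values, which holds here because the number of steps is fixed.

```julia
using ReverseDiff

# Hypothetical toy loss: fixed-step explicit Euler on du/dt = -p[1]*u + p[2],
# summing every state (mimicking sum(sol)). Works with plain and tracked numbers.
function euler_sum(p; dt = 0.01, n = 100)
    u = 1.0
    acc = u
    for _ in 1:n
        u = u + dt * (-p[1] * u + p[2])
        acc += u
    end
    return acc
end

p = [0.5, 0.1]
tape  = ReverseDiff.GradientTape(euler_sum, p)    # record the operations once
ctape = ReverseDiff.compile(tape)                 # compile the recorded tape
grad  = zeros(2)
ReverseDiff.gradient!(grad, ctape, p)             # replay: gradient at p

grad2 = zeros(2)
ReverseDiff.gradient!(grad2, ctape, [0.6, 0.2])   # same tape, new parameter values
println(grad, " ", grad2)
```

Recording and compiling cost something up front, so this mode pays off mainly when the same gradient is evaluated many times, e.g. inside an optimization loop.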


Fantastic, thanks a lot for the detailed reply!