Hi all,
I am new to Julia and I am trying to solve an ODE and obtain its gradient with respect to the parameters `p`. The ODE is the Hodgkin-Huxley model, a standard model in neuroscience (the full ODE is at the end of this message).
I can solve the ODE as follows:
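```julia
# `my_ode` and the `using` statements are in the code at the end of this post; run that first
p = [70.0, 10.0, 0.05];          # g_Na1, g_Kd1, g_leak1
u0 = [-59.0, 0.2, 0.9, 0.1];     # E1, m1, h1, n1
tspan = (0.0, 600.0)
prob = ODEProblem(my_ode, u0, tspan, p)
# fixed-step explicit Euler: 600 / 0.01 = 60,000 steps
t_solve = @elapsed sol = solve(prob, Euler(), dt=0.01);
println("Time to solve: ", t_solve)
```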
And I can get its gradients with:
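```julia
function G(p)
    # convert u0 to the element type of p so the AD number type flows through the solver
    tmp_prob = remake(prob, u0=convert.(eltype(p), u0), p=p)
    # SensitivityADPassThrough lets the AD package differentiate through the solver directly;
    # the scalar loss is the sum over all saved states
    return sum(solve(tmp_prob, Euler(), dt=0.01, sensealg=SensitivityADPassThrough()))
end
t_grad = @elapsed grad = ReverseDiff.gradient(G, p)
println("Time to obtain grad: ", t_grad)
```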
On my machine, solving the ODE takes about 0.05 seconds, but computing the gradient takes about 3.5 seconds.
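In Julia, `@elapsed` on the first call of a function in a session also measures JIT compilation, so a comparison like the one above is more reliable when each call is timed a second time (or with BenchmarkTools.jl's `@btime`, which runs many samples). A minimal stdlib-only sketch of the pattern, where `work` is a hypothetical stand-in for `solve` or `ReverseDiff.gradient`:

```julia
# Hypothetical stand-in for the calls timed above (solve / ReverseDiff.gradient)
work(x) = sum(abs2, x)

x = rand(10^6)

# In a fresh session the first call also compiles `work` for Vector{Float64}...
t_first = @elapsed work(x)
# ...while later calls reuse the already compiled method
t_second = @elapsed work(x)

println("first call: ", t_first, " s; second call: ", t_second, " s")
```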
My questions are:
- Am I doing anything obviously wrong?
- Why is backpropagation about 70 times more expensive than the forward solve? In neural networks, for example, the gradient usually costs only about 1-4 times as much as the forward pass.
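For intuition on why a 1-4× ratio is also what one would expect for an ODE, here is a toy, stdlib-only sketch of hand-written reverse mode (a discrete adjoint) through explicit Euler. It assumes a hypothetical scalar ODE `du/dt = -p*u` instead of the Hodgkin-Huxley model, and uses the same kind of loss as `G` (the sum of all saved states). This only illustrates the operation-count argument; it is not how ReverseDiff or DifferentialEquations.jl implement reverse mode.

```julia
# Toy example (hypothetical ODE du/dt = -p*u, not the HH model):
# one forward sweep stores the trajectory, one backward sweep of the
# same length accumulates dL/dp, where L = sum of all saved states.
function euler_sum_and_grad(p::Float64, u0::Float64, dt::Float64, n::Int)
    # Forward sweep: us[k+1] = us[k] + dt * f(us[k]) with f(u) = -p*u
    us = Vector{Float64}(undef, n + 1)
    us[1] = u0
    for k in 1:n
        us[k+1] = us[k] + dt * (-p * us[k])
    end
    L = sum(us)

    # Backward sweep: ubar holds dL/dus[k+1], starting at the last state
    ubar = 1.0
    dLdp = 0.0
    for k in n:-1:1
        # local derivative of us[k+1] with respect to p is -dt * us[k]
        dLdp += ubar * (-dt * us[k])
        # dL/dus[k] = direct term from the sum (1) + chain rule through step k
        ubar = 1.0 + ubar * (1.0 - dt * p)
    end
    return L, dLdp
end

# 60,000 steps, matching tspan = (0.0, 600.0) with dt = 0.01
L, g = euler_sum_and_grad(0.5, 1.0, 0.01, 60_000)
println("L = ", L, ", dL/dp = ", g)
```

Each sweep is a single loop over the steps, so in this toy the gradient costs a small constant multiple of the forward solve, plus the memory for the stored trajectory.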
Note: I know that I should not be using `Euler` and that `ForwardDiff` is faster in my current example. The model I am eventually interested in has more than 50 parameters, though, so I am curious whether there is some way of making reverse-mode gradients more efficient. I have also tried `Zygote`, but the problem persists.
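For reference, the usual operation-count estimates for a scalar loss $G$ with $n_p$ parameters are as follows (the constants, and the memory reverse mode needs to store or recompute the forward pass, depend on the implementation):

$$
\mathrm{cost}_{\text{forward mode}}(\nabla_p G) = \mathcal{O}(n_p)\,\mathrm{cost}(G),
\qquad
\mathrm{cost}_{\text{reverse mode}}(\nabla_p G) = \mathcal{O}(1)\,\mathrm{cost}(G).
$$

Here $n_p = 3$, so forward mode is competitive, whereas the target model has $n_p > 50$.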
Thanks a lot for your help!
Code for the ODE:
```julia
using ReverseDiff
using DifferentialEquations

# Hodgkin-Huxley neuron: u = [E1, m1, h1, n1], p = [g_Na1, g_Kd1, g_leak1]
function my_ode(du, u, p, t)
    E1 = u[1]  # membrane potential
    m1 = u[2]  # Na activation gate
    h1 = u[3]  # Na inactivation gate
    n1 = u[4]  # K activation gate
    V_T = -63.0
    g_Na1, g_Kd1, g_leak1 = p

    # gating variable m
    alpha_act1 = -0.32 * (E1 - V_T - 13) / (exp(-(E1 - V_T - 13) / 4) - 1)
    beta_act1 = 0.28 * (E1 - V_T - 40) / (exp((E1 - V_T - 40) / 5) - 1)
    du[2] = (alpha_act1 * (1.0 - m1)) - (beta_act1 * m1)

    # gating variable h
    alpha_inact1 = 0.128 * exp(-(E1 - V_T - 17) / 18)
    beta_inact1 = 4 / (1 + exp(-(E1 - V_T - 40) / 5))
    du[3] = (alpha_inact1 * (1.0 - h1)) - (beta_inact1 * h1)

    # gating variable n
    alpha_kal1 = -0.032 * (E1 - V_T - 15) / (exp(-(E1 - V_T - 15) / 5) - 1)
    beta_kal1 = 0.5 * exp(-(E1 - V_T - 10) / 40)
    du[4] = (alpha_kal1 * (1.0 - n1)) - (beta_kal1 * n1)

    # membrane currents and voltage equation
    area = 20000.0
    I_Na1 = -(E1 - 50) * g_Na1 * area * (m1^3) * h1
    I_K1 = -(E1 - (-90)) * g_Kd1 * area * n1^4
    I_leak1 = -(E1 - (-65)) * g_leak1 * area
    I_ext = 1.0 * area
    du[1] = (I_leak1 + I_K1 + I_Na1 + I_ext) / (1 * area)
    return nothing
end
```