Hello,
I’m trying to understand how ForwardDiff propagates Duals through an ode solver, so how the function defining an ODE system can be overloaded (specialized is another word that could be used?) to handle Duals.
Taking the simple Hodgkin Huxley example from DifferentialEquations.jl, if you’re interested in getting the gradient of the final voltage with respect to conductance parameters then you’d do:
using DifferentialEquations, ForwardDiff
# Potassium ion-channel rate functions
alpha_n(v) = (0.02 * (v - 25.0)) / (1.0 - exp((-1.0 * (v - 25.0)) / 9.0))
beta_n(v) = (-0.002 * (v - 25.0)) / (1.0 - exp((v - 25.0) / 9.0))
# Sodium ion-channel rate functions
alpha_m(v) = (0.182 * (v + 35.0)) / (1.0 - exp((-1.0 * (v + 35.0)) / 9.0))
beta_m(v) = (-0.124 * (v + 35.0)) / (1.0 - exp((v + 35.0) / 9.0))
alpha_h(v) = 0.25 * exp((-1.0 * (v + 90.0)) / 12.0)
beta_h(v) = (0.25 * exp((v + 62.0) / 6.0)) / exp((v + 90.0) / 12.0)
function HH!(du, u, p, t)
EK, ENa, EL, C, I = -77.0, 55.0, -65.0, 1, 1
gK,gNa, gL = p
v, n, m, h = u
du[1] = (-(gK * (n^4.0) * (v - EK)) - (gNa * (m^3.0) * h * (v - ENa)) -
(gL * (v - EL)) + I) / C
du[2] = (alpha_n(v) * (1.0 - n)) - (beta_n(v) * n)
du[3] = (alpha_m(v) * (1.0 - m)) - (beta_m(v) * m)
du[4] = (alpha_h(v) * (1.0 - h)) - (beta_h(v) * h)
end
n_inf(v) = alpha_n(v) / (alpha_n(v) + beta_n(v))
m_inf(v) = alpha_m(v) / (alpha_m(v) + beta_m(v))
h_inf(v) = alpha_h(v) / (alpha_h(v) + beta_h(v))
p = [35.0,40.0, 0.3]
u0 = [-60, n_inf(-60), m_inf(-60), h_inf(-60)]
HHprob = ODEProblem(HH!, u0, (0.0, 1000), p)
sol = solve(HHprob);
function loss(p)
sol = solve(HHprob, Tsit5(), p=p)
return sol[end][1]
end
loss(p)
grad = @time ForwardDiff.gradient(loss, p)
@show grad
The following should be equivalent since the custom derivatives I’ve defined below are in fact the derivatives that you’d get in the right hand side of the ode system – and in fact, only in the voltage equation.
# Now TEST CUSTOM ForwardDiff handling of Duals
function myHH!(du, u, p, t)
EK, ENa, EL, C, I = -77.0, 55.0, -65.0, 1, 1
gK, gNa, gL = p
# Unpack the state variables (v: membrane potential, n, m, h: gating variables)
v, n, m, h = u
if v isa ForwardDiff.Dual # Check if we're dealing with Dual numbers (ForwardDiff)
println("Using Dual numbers for differentiation")
# Extract primal values (normal numerical values)
v_val = ForwardDiff.value(v)
n_val = ForwardDiff.value(n)
m_val = ForwardDiff.value(m)
h_val = ForwardDiff.value(h)
# Compute the primal ODE equations (as before)
du_primal1 = (-(gK * (n_val^4.0) * (v_val - EK)) -
(gNa * (m_val^3.0) * h_val * (v_val - ENa)) -
(gL * (v_val - EL)) + I) / C
du_primal2 = (alpha_n(v_val) * (1.0 - n_val)) - (beta_n(v_val) * n_val)
du_primal3 = (alpha_m(v_val) * (1.0 - m_val)) - (beta_m(v_val) * m_val)
du_primal4 = (alpha_h(v_val) * (1.0 - h_val)) - (beta_h(v_val) * h_val)
# Get the partial derivatives for each state variable
v_partials = ForwardDiff.partials(v)
# Calculate the derivatives of the ODE system RHS with respect to the parameters
custom_deriv_gK = -(n_val^4.0) * (v_val - EK) * v_partials[1]
custom_deriv_gNa = -(m_val^3.0) * h_val * (v_val - ENa) * v_partials[2]
custom_deriv_gL = -(v_val - EL) * v_partials[3]
# Create the dual values by combining primal values and derivatives
du[1] = ForwardDiff.Dual(du_primal1, (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL))
du[2] = du_primal2
du[3] = du_primal3
du[4] = du_primal4
else
# Standard ODE computation when not dealing with Duals (normal floating point numbers)
du[1] = (-(gK * (n^4.0) * (v - EK)) -
(gNa * (m^3.0) * h * (v - ENa)) -
(gL * (v - EL)) + I) / C
du[2] = (alpha_n(v) * (1.0 - n)) - (beta_n(v) * n)
du[3] = (alpha_m(v) * (1.0 - m)) - (beta_m(v) * m)
du[4] = (alpha_h(v) * (1.0 - h)) - (beta_h(v) * h)
end
end
myHHprob = ODEProblem(myHH!, u0, (0.0, 1000), p)
mysol = solve(myHHprob);
function myloss(p)
sol = solve(myHHprob, Tsit5(), p=p)
return sol[end][1]
end
myloss(p)
mygrad = @time ForwardDiff.gradient(myloss, p)
@show mygrad
But I’m getting an error in the construction of
du[1]=ForwardDiff.Dual(du_primal1, (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL))
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, ForwardDiff.Dual{ForwardDiff.Tag{typeof(myloss), Float64}, Float64, 3}, 3})
What’s wrong with passing the tuple (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL) to the second argument of Dual ? I can’t see where this might be forcing a Float…
Can anyone spot the issue?
Thanks!!
using Julia Version 1.10.3 on Mac