Understanding how ForwardDiff.Duals are propagated

Hello,

I’m trying to understand how ForwardDiff propagates Duals through an ODE solver, i.e. how the function defining an ODE system can be overloaded (or specialized, if that’s the better word?) to handle Duals.
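
For context, a minimal sketch of what "propagating Duals" means, assuming only ForwardDiff: nothing has to be specialized by hand. Generic Julia code is simply called with a `Dual` argument, and each arithmetic operation applies the chain rule to the partials. The function `f` below is a made-up one-term stand-in for the voltage equation.

```julia
using ForwardDiff

# Made-up, generic stand-in for one term of the voltage equation: no Dual-specific code.
f(v, gK) = -gK * (v + 77.0)

# Seed gK with a single partial of 1.0, so partials(y) holds df/dgK.
gK = ForwardDiff.Dual(35.0, 1.0)
y = f(-60.0, gK)

println(ForwardDiff.value(y))        # -595.0
println(ForwardDiff.partials(y)[1])  # -17.0, i.e. df/dgK = -(v + 77)

# Same derivative through the public API, which seeds the Dual for you.
println(ForwardDiff.derivative(g -> f(-60.0, g), 35.0))  # -17.0
```
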
Taking the simple Hodgkin-Huxley example from DifferentialEquations.jl, if you want the gradient of the final voltage with respect to the conductance parameters, you’d do:

```julia
using DifferentialEquations, ForwardDiff

# Potassium ion-channel rate functions
alpha_n(v) = (0.02 * (v - 25.0)) / (1.0 - exp((-1.0 * (v - 25.0)) / 9.0))
beta_n(v) = (-0.002 * (v - 25.0)) / (1.0 - exp((v - 25.0) / 9.0))

# Sodium ion-channel rate functions
alpha_m(v) = (0.182 * (v + 35.0)) / (1.0 - exp((-1.0 * (v + 35.0)) / 9.0))
beta_m(v) = (-0.124 * (v + 35.0)) / (1.0 - exp((v + 35.0) / 9.0))

alpha_h(v) = 0.25 * exp((-1.0 * (v + 90.0)) / 12.0)
beta_h(v) = (0.25 * exp((v + 62.0) / 6.0)) / exp((v + 90.0) / 12.0)

function HH!(du, u, p, t)
    EK, ENa, EL, C, I = -77.0, 55.0, -65.0, 1, 1
    gK, gNa, gL = p
    v, n, m, h = u

    du[1] = (-(gK * (n^4.0) * (v - EK)) - (gNa * (m^3.0) * h * (v - ENa)) -
             (gL * (v - EL)) + I) / C
    du[2] = (alpha_n(v) * (1.0 - n)) - (beta_n(v) * n)
    du[3] = (alpha_m(v) * (1.0 - m)) - (beta_m(v) * m)
    du[4] = (alpha_h(v) * (1.0 - h)) - (beta_h(v) * h)
end

n_inf(v) = alpha_n(v) / (alpha_n(v) + beta_n(v))
m_inf(v) = alpha_m(v) / (alpha_m(v) + beta_m(v))
h_inf(v) = alpha_h(v) / (alpha_h(v) + beta_h(v))

p = [35.0, 40.0, 0.3]
u0 = [-60, n_inf(-60), m_inf(-60), h_inf(-60)]

HHprob = ODEProblem(HH!, u0, (0.0, 1000), p)
sol = solve(HHprob);

# Final membrane voltage as a function of the conductances
function loss(p)
    sol = solve(HHprob, Tsit5(), p=p)
    return sol.u[end][1]  # last saved state, first component (v)
end

loss(p)
grad = @time ForwardDiff.gradient(loss, p)
@show grad
```

The following should be equivalent, since the custom derivatives I’ve defined below are exactly the derivatives you’d get from the right-hand side of the ODE system, and in fact they only appear in the voltage equation.

```julia
# Now test custom ForwardDiff handling of Duals

function myHH!(du, u, p, t)
    EK, ENa, EL, C, I = -77.0, 55.0, -65.0, 1, 1
    gK, gNa, gL = p

    # Unpack the state variables (v: membrane potential, n, m, h: gating variables)
    v, n, m, h = u

    if v isa ForwardDiff.Dual  # Check if we're dealing with Dual numbers (ForwardDiff)
        println("Using Dual numbers for differentiation")
        # Extract primal values (normal numerical values)
        v_val = ForwardDiff.value(v)
        n_val = ForwardDiff.value(n)
        m_val = ForwardDiff.value(m)
        h_val = ForwardDiff.value(h)

        # Compute the primal ODE equations (as before)
        du_primal1 = (-(gK * (n_val^4.0) * (v_val - EK)) -
                      (gNa * (m_val^3.0) * h_val * (v_val - ENa)) -
                      (gL * (v_val - EL)) + I) / C

        du_primal2 = (alpha_n(v_val) * (1.0 - n_val)) - (beta_n(v_val) * n_val)
        du_primal3 = (alpha_m(v_val) * (1.0 - m_val)) - (beta_m(v_val) * m_val)
        du_primal4 = (alpha_h(v_val) * (1.0 - h_val)) - (beta_h(v_val) * h_val)

        # Get the partial derivatives for each state variable
        v_partials = ForwardDiff.partials(v)

        # Calculate the derivatives of the ODE system RHS with respect to the parameters
        custom_deriv_gK = -(n_val^4.0) * (v_val - EK) * v_partials[1]
        custom_deriv_gNa = -(m_val^3.0) * h_val * (v_val - ENa) * v_partials[2]
        custom_deriv_gL = -(v_val - EL) * v_partials[3]

        # Create the dual values by combining primal values and derivatives
        du[1] = ForwardDiff.Dual(du_primal1, (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL))
        du[2] = du_primal2
        du[3] = du_primal3
        du[4] = du_primal4
    else
        # Standard ODE computation when not dealing with Duals (normal floating point numbers)
        du[1] = (-(gK * (n^4.0) * (v - EK)) -
                 (gNa * (m^3.0) * h * (v - ENa)) -
                 (gL * (v - EL)) + I) / C
        du[2] = (alpha_n(v) * (1.0 - n)) - (beta_n(v) * n)
        du[3] = (alpha_m(v) * (1.0 - m)) - (beta_m(v) * m)
        du[4] = (alpha_h(v) * (1.0 - h)) - (beta_h(v) * h)
    end
end

myHHprob = ODEProblem(myHH!, u0, (0.0, 1000), p)
mysol = solve(myHHprob);

function myloss(p)
    sol = solve(myHHprob, Tsit5(), p=p)
    return sol.u[end][1]  # last saved state, first component (v)
end

myloss(p)
mygrad = @time ForwardDiff.gradient(myloss, p)
@show mygrad
```

But I’m getting an error in the construction of:

```julia
du[1] = ForwardDiff.Dual(du_primal1, (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL))
```

```
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, ForwardDiff.Dual{ForwardDiff.Tag{typeof(myloss), Float64}, Float64, 3}, 3})
```

What’s wrong with passing the tuple (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL) as the second argument of Dual? I can’t see where this might be forcing a Float…

Can anyone spot the issue?

Thanks!!


Using Julia version 1.10.3 on Mac.

du can be Dual-valued even when u is not, if it’s a Rosenbrock method that needs the derivative w.r.t. t.
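
A small sketch of that situation, assuming nothing beyond ForwardDiff: when a method needs ∂f/∂t, it can differentiate with respect to t alone, so only t is a Dual and the du buffer has to be Dual-typed while u stays Float64. `rhs!` is a made-up RHS, and the closure is only a rough stand-in for what a Rosenbrock method does internally, not OrdinaryDiffEq’s actual code.

```julia
using ForwardDiff

# Hypothetical time-dependent RHS (not from the thread): du/dt = -p[1]*u + sin(t)
function rhs!(du, u, p, t)
    du[1] = -p[1] * u[1] + sin(t)
    return nothing
end

u = [1.0]    # plain Float64 state
p = [2.0]
t0 = 0.5

# Rough stand-in for a method that needs df/dt: only t is a Dual here,
# so the du buffer must hold Duals even though u is plain Float64.
dfdt = ForwardDiff.derivative(t0) do t
    du = similar(u, typeof(t))
    rhs!(du, u, p, t)
    du
end

println(eltype(u))  # Float64
println(dfdt)       # approximately [cos(0.5)]
```

So a check like `v isa ForwardDiff.Dual` on the state does not tell you whether `du` holds Duals.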

Hey Chris,

Sorry, I’m not following. Is Tsit5 a Rosenbrock method? Is the suggestion to change the solver?

Do you mean the partials should be (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL, deriv_t)?

Thanks

Are you sure it’s in the construction, i.e. the right-hand side? Not in the setindex!, which does a convert?

Which type does du have when it fails? Can you look at it? Split up the assignment into:

```julia
tmp = ForwardDiff.Dual...
println("rhs type ", typeof(tmp))
println("lhs type ", eltype(du))
du[1] = tmp
```
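
A minimal reproduction of that idea outside the solver, assuming a hypothetical tag type `DemoTag` in place of `ForwardDiff.Tag{typeof(myloss), Float64}`: the untagged constructor happily wraps an already-Dual primal, and the failure only happens at `du[1] = tmp`, i.e. when setindex! converts the value to `eltype(du)`.

```julia
using ForwardDiff
using ForwardDiff: Dual

struct DemoTag end                 # hypothetical tag, stands in for Tag{typeof(myloss), Float64}
const D = Dual{DemoTag, Float64, 3}

du = zeros(D, 4)                   # buffer of tagged Duals, like du inside the solver
gK = Dual{DemoTag}(35.0, 1.0, 0.0, 0.0)   # a Dual-valued parameter

primal = gK * 2.0                  # still a Dual{DemoTag}, not a Float64
tmp = Dual(primal, (0.1, 0.2, 0.3))       # untagged Dual whose value/partials are Dual{DemoTag}
println("rhs type ", typeof(tmp))
println("lhs type ", eltype(du))

# The constructor succeeded; the assignment (setindex! -> convert) is what fails.
err = try
    du[1] = tmp
    nothing
catch e
    e
end
println(typeof(err))
```
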

Oh sorry, I misread the first time.

First of all, your computed primal value is not correct. You also need to:

```julia
gK, gNa, gL = ForwardDiff.value.(p)
```

That will fix your current error, which is that your primal is Dual-valued because the parameters are Dual-valued.
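
A short sketch of the difference, with a hypothetical tag `DemoTag` and the parameters seeded the way `ForwardDiff.gradient` would seed them (one partial per parameter): anything computed from the Dual parameters is itself Dual-valued, while `ForwardDiff.value.(p)` gives plain `Float64`s.

```julia
using ForwardDiff
using ForwardDiff: Dual

struct DemoTag end   # hypothetical tag for illustration

# Parameters seeded with one partial each (d/dgK, d/dgNa, d/dgL)
p = [Dual{DemoTag}(35.0, 1.0, 0.0, 0.0),
     Dual{DemoTag}(40.0, 0.0, 1.0, 0.0),
     Dual{DemoTag}(0.3, 0.0, 0.0, 1.0)]

gK, gNa, gL = p
println(typeof(gK * 17.0))   # a Dual: the "primal" still carries partials

gK, gNa, gL = ForwardDiff.value.(p)
println(typeof(gK * 17.0))   # Float64
```
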

Next, your dual isn’t correct. ForwardDiff.Dual is not the right constructor, since you’re not forcing the correct tag, so even once the primal is fixed it’s going to fail with a tag error. You would need to push the tag forward manually, e.g. ForwardDiff.Dual{ForwardDiff.tagtype(eltype(du))}(...).
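
A minimal sketch of pushing the tag forward by hand, again with a hypothetical `DemoTag` and a made-up primal value: `ForwardDiff.tagtype` reads the tag off the element type of `du` instead of indexing into `typeof(du).parameters`, and a Dual built with that tag and a plain `Float64` primal is stored without any conversion.

```julia
using ForwardDiff
using ForwardDiff: Dual

struct DemoTag end                      # hypothetical tag for illustration
du = zeros(Dual{DemoTag, Float64, 3}, 4)

T = ForwardDiff.tagtype(eltype(du))     # DemoTag
println(typeof(du).parameters[1].parameters[1] === T)  # same tag via type internals

primal = 12.5                           # made-up plain Float64 primal
tmp = Dual{T}(primal, (0.1, 0.2, 0.3))  # Dual{DemoTag, Float64, 3}
du[1] = tmp                             # types match, so setindex! has nothing to convert
```
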

But then you’re also missing dual terms, because the propagation of the derivatives with respect to the parameters (the partials carried by u) is being dropped. So your derivative won’t be correct without adding those terms.
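
A minimal sketch of what adding those terms means for the voltage equation alone, assuming all inputs share one hypothetical tag `DemoTag` and carry three partials (d/dgK, d/dgNa, d/dgL): the tangent of du[1] is the chain rule over all seven inputs, not only the explicit parameter terms. The state values and sensitivities are made up; the point is that the hand-written pushforward matches what generic code produces automatically. The same reasoning applies to du[2:4], which depend on v, n, m, h even though they contain no conductances explicitly.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials

struct DemoTag end   # hypothetical tag shared by all inputs

const EK = -77.0
const ENa = 55.0
const EL = -65.0
const Cm = 1.0       # capacitance (C in HH!)
const Iext = 1.0     # injected current (I in HH!)

# Voltage equation from HH!, written generically
f1(v, n, m, h, gK, gNa, gL) =
    (-(gK * n^4 * (v - EK)) - (gNa * m^3 * h * (v - ENa)) - (gL * (v - EL)) + Iext) / Cm

# Same equation with the tangent pushed forward by hand (chain rule over all 7 inputs)
function f1_manual(v, n, m, h, gK, gNa, gL)
    vv, nv, mv, hv = value(v), value(n), value(m), value(h)
    gKv, gNav, gLv = value(gK), value(gNa), value(gL)
    dv, dn, dm, dh = partials(v), partials(n), partials(m), partials(h)
    dgK, dgNa, dgL = partials(gK), partials(gNa), partials(gL)

    val = f1(vv, nv, mv, hv, gKv, gNav, gLv)
    # explicit parameter terms (the part myHH! tried to write)
    d = -(nv^4 * (vv - EK)) * dgK - (mv^3 * hv * (vv - ENa)) * dgNa - (vv - EL) * dgL
    # state terms: parameter sensitivities carried in the partials of u
    d += -gKv * (4 * nv^3 * (vv - EK) * dn + nv^4 * dv)
    d += -gNav * (3 * mv^2 * hv * (vv - ENa) * dm + mv^3 * (vv - ENa) * dh + mv^3 * hv * dv)
    d += -gLv * dv
    return Dual{DemoTag}(val, d / Cm)
end

# Made-up states with nonzero sensitivities, and seeded parameters
v   = Dual{DemoTag}(-60.0, 0.5, -0.2, 0.1)
n   = Dual{DemoTag}(0.3, 0.01, 0.02, -0.01)
m   = Dual{DemoTag}(0.05, -0.003, 0.004, 0.001)
h   = Dual{DemoTag}(0.6, 0.02, -0.01, 0.005)
gK  = Dual{DemoTag}(35.0, 1.0, 0.0, 0.0)
gNa = Dual{DemoTag}(40.0, 0.0, 1.0, 0.0)
gL  = Dual{DemoTag}(0.3, 0.0, 0.0, 1.0)

auto = f1(v, n, m, h, gK, gNa, gL)           # generic code: Duals propagate themselves
manual = f1_manual(v, n, m, h, gK, gNa, gL)
println(value(auto) ≈ value(manual))
println(collect(partials(auto)) ≈ collect(partials(manual)))
```
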

Thanks both.
@sgaure you were right, I had to extract the Tag from the lhs.

After fixing that and incorporating gK, gNa, gL = ForwardDiff.value.(p), the code reads:

```julia
function myHH!(du, u, p, t)
    EK, ENa, EL, C, I = -77.0, 55.0, -65.0, 1, 1
    gK, gNa, gL = p

    # Unpack the state variables (v: membrane potential, n, m, h: gating variables)
    v, n, m, h = u

    if p isa ForwardDiff.Dual  # Check if we're dealing with Dual numbers (ForwardDiff)
        println("Using Dual numbers for differentiation")
        # Extract primal values (normal numerical values)
        v_val = ForwardDiff.value(v)
        n_val = ForwardDiff.value(n)
        m_val = ForwardDiff.value(m)
        h_val = ForwardDiff.value(h)

        gK, gNa, gL = ForwardDiff.value.(p)

        # Compute the primal ODE equations (as before)
        du_primal1 = (-(gK * (n_val^4.0) * (v_val - EK)) -
                      (gNa * (m_val^3.0) * h_val * (v_val - ENa)) -
                      (gL * (v_val - EL)) + I) / C

        du_primal2 = (alpha_n(v_val) * (1.0 - n_val)) - (beta_n(v_val) * n_val)
        du_primal3 = (alpha_m(v_val) * (1.0 - m_val)) - (beta_m(v_val) * m_val)
        du_primal4 = (alpha_h(v_val) * (1.0 - h_val)) - (beta_h(v_val) * h_val)

        # Get the partial derivatives for each state variable
        v_partials = ForwardDiff.partials(v)
        n_partials = ForwardDiff.partials(n)
        m_partials = ForwardDiff.partials(m)
        h_partials = ForwardDiff.partials(h)

        # Calculate the derivatives of the ODE system RHS with respect to the parameters
        custom_deriv_gK = -(n_val^4.0) * (v_val - EK) * v_partials[1]
        custom_deriv_gNa = -(m_val^3.0) * h_val * (v_val - ENa) * v_partials[2]
        custom_deriv_gL = -(v_val - EL) * v_partials[3]

        # Extract the tag (the first type parameter of the Dual)
        dual_type = typeof(du).parameters[1]
        tag = dual_type.parameters[1]
        tmp = ForwardDiff.Dual{tag}(du_primal1, (custom_deriv_gK, custom_deriv_gNa, custom_deriv_gL))
        println("rhs type ", typeof(tmp))
        println("lhs type ", eltype(du))

        du[1] = tmp
        # Create the dual values by combining primal values and derivatives
        du[2] = ForwardDiff.Dual{tag}(du_primal2, (0.0*n_partials[1],0.0*n_partials[2],0.0*n_partials[3]))
        du[3] = ForwardDiff.Dual{tag}(du_primal3, (0.0*m_partials[1],0.0*m_partials[2],0.0*m_partials[3]))
        du[4] = ForwardDiff.Dual{tag}(du_primal4, (0.0*h_partials[1],0.0*h_partials[2],0.0*h_partials[3]))
    else
        # Standard ODE computation when not dealing with Duals (normal floating point numbers)
        du[1] = (-(gK * (n^4.0) * (v - EK)) -
                 (gNa * (m^3.0) * h * (v - ENa)) -
                 (gL * (v - EL)) + I) / C
        du[2] = (alpha_n(v) * (1.0 - n)) - (beta_n(v) * n)
        du[3] = (alpha_m(v) * (1.0 - m)) - (beta_m(v) * m)
        du[4] = (alpha_h(v) * (1.0 - h)) - (beta_h(v) * h)
    end
end
```

where the other derivatives w.r.t. the parameters are no longer dropped, so that is addressed, and all du[i] have their corresponding partials.
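
One thing worth double-checking in the branch condition of myHH!: `p` is a `Vector`, so `p isa ForwardDiff.Dual` is `false` even when its entries are Duals, and the `else` branch runs instead. A minimal sketch of an element-type check, with a hypothetical tag `DemoTag`:

```julia
using ForwardDiff
using ForwardDiff: Dual

struct DemoTag end   # hypothetical tag for illustration
p = [Dual{DemoTag}(35.0, 1.0, 0.0, 0.0),
     Dual{DemoTag}(40.0, 0.0, 1.0, 0.0),
     Dual{DemoTag}(0.3, 0.0, 0.0, 1.0)]

println(p isa ForwardDiff.Dual)                          # false: p is a Vector
println(eltype(p) <: ForwardDiff.Dual)                   # true: tests the elements' type
println(eltype([35.0, 40.0, 0.3]) <: ForwardDiff.Dual)   # false for plain Float64 parameters
```
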

But now calling the gradient with

```julia
function myloss(p)
    myHHprob = ODEProblem(myHH!, eltype(p).(u0), (0.0, 1000), p)
    sol = solve(myHHprob, Tsit5())
    return sol.u[end][1]  # last saved state, first component (v)
end
mygrad = @time ForwardDiff.gradient(myloss, p)
```

throws

```
ERROR: LoadError: TypeError: in cfunction, expected Union{}, got a value of type Nothing
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:137 [inlined]
  [2] do_ccall
    @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:125 [inlined]
  [3] FunctionWrapper
    @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:144 [inlined]
  [4] _call
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:12 [inlined]
  [5] FunctionWrappersWrapper
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10 [inlined]
  [6] ODEFunction
    @ ~/.julia/packages/SciMLBase/SDjaO/src/scimlfunctions.jl:2296 [inlined]
  [7] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/perform_step/low_order_rk_perform_step.jl:799
  [8] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:518
  [9] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:11 [inlined]
 [10] #__solve#799
    @ ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:6 [inlined]
 [11] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:1 [inlined]
 [12] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612 [inlined]
 [13] solve_call
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:569 [inlined]
 [14] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080 [inlined]
 [15] solve_up
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1066 [inlined]
 [16] #solve#51
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003 [inlined]
 [17] solve(prob::ODEProblem{…}, args::Tsit5{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:993
 [18] myloss(p::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(myloss), Float64}, Float64, 3}})
```

What am I missing in myHH! to make the gradient work??

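For now, construct the problem with full specialization, which skips the FunctionWrappers-based wrapping that shows up in the stack trace:

```julia
ODEProblem{true, SciMLBase.FullSpecialize}(myHH!, eltype(p).(u0), (0.0, 1000), p)
```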
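
A self-contained sketch of that workaround on a made-up scalar decay ODE (`decay!` and `decay_loss` are hypothetical names, not from the thread), checked against the closed-form derivative d/dk e^(-2k) = -2e^(-2k). It assumes `SciMLBase` is reachable the same way as in the reply above; if it is not in scope in your environment, add it and `import SciMLBase`.

```julia
using DifferentialEquations, ForwardDiff

# Hypothetical in-place ODE: du/dt = -k u, with p = [k]; written generically, no Dual code
function decay!(du, u, p, t)
    du[1] = -p[1] * u[1]
    return nothing
end

function decay_loss(p)
    u0 = eltype(p).([1.0])   # promote u0 to p's eltype (Dual under gradient), as in myloss
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(decay!, u0, (0.0, 2.0), p)
    sol = solve(prob, Tsit5(); reltol = 1e-10, abstol = 1e-10)
    return sol.u[end][1]     # u(2) = exp(-2k)
end

k = 0.7
g = ForwardDiff.gradient(decay_loss, [k])
println(g[1], " vs analytic ", -2 * exp(-2k))
```
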