How to differentiate an ODEProblem (built from another problem) with respect to its parameters

Hi everyone,

The basic issue is differentiating (using ForwardDiff.jl) the solution of an ODEProblem from DifferentialEquations.jl with respect to its parameters, where that ODEProblem is built from another ODEProblem.

My problem arises in this code:

function testfunc(p, costprob)
    # Lift the initial state to p's element type so duals can flow through the solve.
    T = eltype(p)
    u0_T = map(x -> convert(T, x), costprob.u0)
    sol = solve(remake(costprob; u0=u0_T, p=p), save_everystep=false)
    return sol[end][end]
end

function cost_grad(p_, nomprob)
    # NOTE(review): `param_prob` is not defined anywhere in this post.
    pertprob = param_prob(p_, nomprob)
    costprob, _ = make_cost_problem(nomprob, pertprob)
    endcost(q) = testfunc(q, costprob)
    return ForwardDiff.gradient(endcost, p_)
end

The full code is below. I apologise for its length, but I couldn’t work out how to trim it, since I don’t know where the issue lies.

using ParameterizedFunctions
using DiffEqBiological
using DifferentialEquations
using DiffEqSensitivity
using Random
using ForwardDiff
using Optim
using Flux 
using DiffEqFlux
using LinearAlgebra
using Statistics
using Plots
using LaTeXStrings

We take the Connor–Stevens model (a model of neuronal dynamics) as our original (nominal) problem.

The following code defines its vector field, parameters and initial conditions:

# Shifts added to V inside the m, h and n gating rate functions below.
# They are `const` because those functions are called from the ODE right-hand side.
# Non-const globals are untyped, which makes every rate evaluation type-unstable.
const m_shift = -5.3
const h_shift = -12.0
const n_shift = -4.3

# Gating equations
# Hodgkin-Huxley with shifts in temperature factor
# (the shifts m_shift / h_shift / n_shift are globals defined above).

# Compute `x / (exp(x / k) - 1)`, using `expm1` to avoid cancellation near `x = 0`.
# At `x == 0` the expression is 0/0, so return the first-order Taylor expansion `k - x/2`.
# That expansion has the correct limit `k` and, for ForwardDiff duals, the correct slope -1/2.
_x_over_expm1(x, k) = x == 0 ? k - x / 2 : x / expm1(x / k)

# α_m = -.1 * x / (exp(-x / 10) - 1) with x = V + 35 + m_shift; singular at x = 0 (limit 1).
α_m(V) = .1 * _x_over_expm1(-(V + 35 + m_shift), 10)
β_m(V) = 4 * exp(-(V + 60 + m_shift) / 18)
m_∞(V) = α_m(V) / (α_m(V) + β_m(V))
τ_m(V) = 1 / (3.8 * (α_m(V) + β_m(V)))
α_h(V) = .07 * exp(-(V + 60 + h_shift) / 20)
β_h(V) = 1 / (1 + exp(-(V + 30 + h_shift) / 10))
h_∞(V) = α_h(V) / (α_h(V) + β_h(V))
τ_h(V) = 1 / (3.8 * (α_h(V) + β_h(V)))
# α_n = -.01 * x / (exp(-x / 10) - 1) with x = V + 50 + n_shift; singular at x = 0 (limit 0.1).
α_n(V) = .01 * _x_over_expm1(-(V + 50 + n_shift), 10)
β_n(V) = .125 * exp(-(V + 60 + n_shift) / 80)
n_∞(V) = α_n(V) / (α_n(V) + β_n(V))
# τ_n is doubled
τ_n(V) = 2 / (3.8 * (α_n(V) + β_n(V)))
# The A current (the .3333 exponent approximates a cube root)
a_∞(V) = (.0761 * exp((V + 94.22) / 31.84) / (1 + exp((V + 1.17) / 28.93)))^(.3333)
τ_a(V) = .3632 + 1.158 / (1 + exp((V + 55.96) / 20.12))
b_∞(V) = 1 / (1 + exp((V + 53.3) / 14.54))^4
τ_b(V) = 1.24 + 2.678 / (1 + exp((V + 50) / 16.027))

# The ion currents.
# The explicit forms take the conductance g and reversal potential E as arguments.
I_Na(m, h, V, g_Na, E_Na) = g_Na * m^3 * h * (V - E_Na)
I_K(n, V, g_K, E_K) = g_K * n^4 * (V - E_K)
I_A(a, b, V, g_A, E_A) = g_A * b * a^3 * (V - E_A)
I_l(V, g_l, E_l) = g_l * (V - E_l)

# The original short forms read the globals g_Na, E_Na, g_K, ... .
# NOTE(review): this script never defines those globals; they exist only as `@ode_def`
# parameter names. Calling these short forms therefore throws UndefVarError unless the
# globals are defined elsewhere.
I_Na(m, h, V) = I_Na(m, h, V, g_Na, E_Na)
I_K(n, V) = I_K(n, V, g_K, E_K)
I_A(a, b, V) = I_A(a, b, V, g_A, E_A)
I_l(V) = I_l(V, g_l, E_l)

# The applied current: the constant 15. on the closed window 50 ≤ t ≤ 200, and 0. elsewhere.
I_app(t) = (50 ≤ t ≤ 200) ? 15. : 0.

# Connor-Stevens vector field, generated with ParameterizedFunctions' `@ode_def`.
# States: V and the gating variables m, h, n, a, b.
# Parameters (in the order listed after `end`): g_Na g_K g_A g_l E_Na E_K E_A E_l.
# dV is minus the sum of the Na, K, A and leak currents (the same formulas as
# I_Na / I_K / I_A / I_l, written inline with the ODE parameters), plus I_app(t).
# The dV expression continues onto the next line inside the open parenthesis.
# Each gating variable x relaxes towards x_∞(V) with time constant τ_x(V).
vf_lean = @ode_def begin
    dV = - (g_Na * m^3 * h * (V - E_Na) + g_K * n^4 * (V - E_K) + g_A * b * a^3 * (V - E_A)
    + g_l *  (V - E_l)) + I_app(t)
    dm = (m_∞(V) - m) / (τ_m(V))
    dh = (h_∞(V) - h) / (τ_h(V))
    dn = (n_∞(V) - n) / (τ_n(V))
    da = (a_∞(V) - a) / (τ_a(V))
    db = (b_∞(V) - b) / (τ_b(V))
end g_Na g_K g_A g_l E_Na E_K E_A E_l

# Parameter vector, in the order declared by vf_lean:
p = [
    120.0,  # g_Na
    20.0,   # g_K
    47.7,   # g_A
    0.3,    # g_l
    55.0,   # E_Na
    -72.0,  # E_K
    -75.0,  # E_A
    -17.0,  # E_l
]

# Initial conditions for the states (V, m, h, n, a, b), in vf_lean's state order
V0, m0, h0, n0, a0, b0 = -74.0, 0.004, 0.986, 0.1, 0.509, 0.426
u0 = [V0, m0, h0, n0, a0, b0]

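We define a function that perturbs the parameters slightly. It applies multiplicative log-normal noise to each parameter's magnitude while preserving its sign: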

"""
    perturb_params(_p, δ::Real; rng::AbstractRNG = Random.GLOBAL_RNG)

Return a perturbed copy of `_p`. Each magnitude is multiplied by `exp(δ * z)`, where `z`
is a standard-normal draw from `rng`. The sign of each entry is preserved, and zero
entries stay zero.

Working on `abs.(_p)` keeps `log` away from negative arguments. `sign.(_p)` restores the
original signs afterwards.

`δ` accepts any real number (e.g. `0` for no perturbation). Pass an explicit `rng` for
reproducible perturbations; by default the global RNG is used, as before.
"""
function perturb_params(_p, δ::Real; rng::AbstractRNG = Random.GLOBAL_RNG)
    noise = randn(rng, length(_p))
    return sign.(_p) .* exp.(log.(abs.(_p)) .+ δ .* noise)
end

We then create a function that combines the nominal and perturbed problems into a single ODEProblem. That problem also integrates a cost term:

"""
    make_cost_problem(nomprob::ODEProblem, pertprob::ODEProblem)

Combine a nominal and a perturbed problem into one out-of-place `ODEProblem`.
Its state is `[u_nom; u_pert; cost]`.

The pinning signal is `I_pin = f_nom(u_nom)[1] - f_pert(u_pert)[1]`, the difference of
the first components of the two vector fields. It is fed into the perturbed system, and
the last state component integrates `I_pin^2`.

The combined problem uses:
- `nomprob.tspan` as its time span;
- `vcat(nomprob.u0, pertprob.u0, 0.)` as its initial state;
- `pertprob.p` as its parameters.

`nomprob.p` is held fixed inside the vector field.

Returns `(cost_problem, (dim1u, dim2u))`, the state dimensions of the two inputs.
"""
function make_cost_problem(nomprob::ODEProblem, pertprob::ODEProblem)
    dim1u, dim2u = length(nomprob.u0), length(pertprob.u0)
    nom_p = nomprob.p
    nom_idxs = 1:dim1u
    pert_idxs = (dim1u + 1):(dim1u + dim2u)

    new_vf = function (u, p, t)
        u_nom = u[nom_idxs]
        u_pert = u[pert_idxs]
        # Evaluate the nominal field once; it feeds both I_pin and the output.
        du_nom = nomprob.f(u_nom, nom_p, t)
        I_pin = du_nom[1] - pertprob.f(u_pert, p, t)[1]

        # The buffer's eltype must follow I_pin. Under ForwardDiff, I_pin is a Dual, and
        # storing it into a Float64 `zeros(n)` buffer throws a conversion error.
        I_pin_term = zeros(typeof(I_pin), dim2u)
        I_pin_term[1] = I_pin

        # NOTE(review): I_pin is added to the perturbed *state* (first component) before
        # the field is evaluated, not to its derivative. Confirm this is intended.
        return vcat(du_nom,
                    pertprob.f(u_pert .+ I_pin_term, p, t),
                    I_pin^2)
    end

    new_p = pertprob.p
    new_u0 = vcat(nomprob.u0, pertprob.u0, 0.)
    cost_problem = ODEProblem(new_vf, new_u0, nomprob.tspan, new_p)
    return cost_problem, (dim1u, dim2u)
end

"""
    get_cost(prob1::ODEProblem, prob2::ODEProblem)

Build the combined cost problem from `prob1` (nominal) and `prob2` (perturbed), solve it,
and return the last component of the final state (the accumulated cost).
"""
function get_cost(prob1::ODEProblem, prob2::ODEProblem)
    costprob, _ = make_cost_problem(prob1, prob2)
    final_state = solve(costprob, save_everystep=false)[end]
    return final_state[end]
end

"""
    testfunc(p, costprob)

Solve `costprob` with parameters `p` and return the last component of the final state.

The initial state is converted to `eltype(p)`. When `p` holds ForwardDiff duals, the
state can then carry them too.
"""
function testfunc(p, costprob)
    T = eltype(p)
    u0_T = map(x -> convert(T, x), costprob.u0)
    sol = solve(remake(costprob; u0=u0_T, p=p), save_everystep=false)
    return sol[end][end]
end

"""
    cost_grad(p_, nomprob)

Gradient, via `ForwardDiff.gradient` at `p_`, of `testfunc` on the combined cost problem.
That problem is built from `nomprob` and a perturbed problem at `p_`.
"""
function cost_grad(p_, nomprob)
    # NOTE(review): `param_prob` is not defined in this script. Presumably it returns a
    # copy of `nomprob` with parameters `p_`; confirm where it comes from.
    pertprob = param_prob(p_, nomprob)
    costprob, _ = make_cost_problem(nomprob, pertprob)
    endcost(q) = testfunc(q, costprob)
    return ForwardDiff.gradient(endcost, p_)
end

"""
    cost_hess(_p, nomprob)

Jacobian of `q -> cost_grad(q, nomprob)` at `_p`, computed with ForwardDiff.
Because `cost_grad` itself uses ForwardDiff, the differentiation is nested.
"""
function cost_hess(_p, nomprob)
    return ForwardDiff.jacobian(q -> cost_grad(q, nomprob), _p)
end

The problems returned by `make_cost_problem` solve without error, but computing the gradient with `cost_grad` fails.

If anyone would be able to shed some light on this I’d be very grateful. Thanks!

See the sensitivity-analysis documentation at https://docs.juliadiffeq.org/latest/analysis/sensitivity/ — just use `concrete_solve` with `sensealg=ForwardDiffAdjoint()`, and `Zygote.gradient` will then use ForwardDiff internally.

Otherwise, please post the error message; without it, this is much harder to debug.