How to differentiate on ODEProblem (built from other problem) w.r.t. its parameters

Hi everyone,

The basic issue is: differentiating (using `ForwardDiff.jl`) an `ODEProblem` from `DifferentialEquations.jl` with respect to parameters, where the `ODEProblem` is built from another `ODEProblem`.

My problem arises in this code:

``````function testfunc(p, costprob)
prob = remake(costprob; u0=convert.(eltype(p), costprob.u0), p=p)
return solve(prob, save_everystep=false)[end][end]
end

testpertprob = param_prob(p_, nomprob)
costprob = make_cost_problem(nomprob, testpertprob)[1]
return ForwardDiff.gradient(p -> testfunc(p, costprob), p_)
end
``````

The full code is below. I apologise for it being a lot, but I couldn’t figure out how to prune it since I don’t know where the issue lies.

``````using ParameterizedFunctions
using DiffEqBiological
using DifferentialEquations
using DiffEqSensitivity
using Random
using ForwardDiff
using Optim
using Flux
using DiffEqFlux
using LinearAlgebra
using Statistics
using Plots
using LaTeXStrings
``````

We take the Connor-Stevens model (a model of neuronal dynamics) as our original problem:

This is where we create the original problem:

``````m_shift = -5.3
h_shift = -12
n_shift = -4.3

# Gating equations
# Hodgkin-Huxley with shifts in temperature factor
α_m(V) = -.1 * (V + 35 + m_shift) / (exp(-(V + 35 + m_shift) / 10) - 1)
β_m(V) = 4 * exp(-(V + 60 + m_shift) / 18)
m_∞(V) = α_m(V) / (α_m(V) + β_m(V))
τ_m(V) = 1 / (3.8 * (α_m(V) + β_m(V)))
α_h(V) = .07 * exp(-(V + 60 + h_shift) / 20)
β_h(V) = 1 / (1 + exp(-(V + 30 + h_shift) / 10))
h_∞(V) = α_h(V) / (α_h(V) + β_h(V))
τ_h(V) = 1 / (3.8 * (α_h(V) + β_h(V)))
α_n(V) = -.01 * (V + 50 + n_shift) / (exp(-(V + 50 + n_shift) / 10) - 1)
β_n(V) = .125 * exp(-(V + 60 + n_shift) / 80)
n_∞(V) = α_n(V) / (α_n(V) + β_n(V))
# τ_n is doubled
τ_n(V) = 2 / (3.8 * (α_n(V) + β_n(V)))
# The A current
a_∞(V) = (.0761 * exp((V + 94.22) / 31.84) / (1 + exp((V + 1.17) / 28.93)))^(.3333)
τ_a(V) = .3632 + 1.158 / (1 + exp((V + 55.96) / 20.12))
b_∞(V) = 1 / (1 + exp((V + 53.3) / 14.54))^4
τ_b(V) = 1.24 + 2.678 / (1 + exp((V + 50) / 16.027))

# The ion currents
I_Na(m, h, V) = g_Na * m^3 * h * (V - E_Na)
I_K(n, V) = g_K * n^4 * (V - E_K)
I_A(a, b, V) = g_A * b * a^3 * (V - E_A)
I_l(V) = g_l *  (V - E_l)

# The applied current
function I_app(t)
if 50 ≤ t ≤ 200
return 15.
else
return 0.
end
end

vf_lean = @ode_def begin
dV = - (g_Na * m^3 * h * (V - E_Na) + g_K * n^4 * (V - E_K) + g_A * b * a^3 * (V - E_A)
+ g_l *  (V - E_l)) + I_app(t)
dm = (m_∞(V) - m) / (τ_m(V))
dh = (h_∞(V) - h) / (τ_h(V))
dn = (n_∞(V) - n) / (τ_n(V))
da = (a_∞(V) - a) / (τ_a(V))
db = (b_∞(V) - b) / (τ_b(V))
end g_Na g_K g_A g_l E_Na E_K E_A E_l

p = [120., 20., 47.7, 0.3, 55., -72., -75., -17.]

# Initial conditions
V0 = -74.
m0 = 0.004
h0 = 0.986
n0 = 0.1
a0 = 0.509
b0 = 0.426
u0 = [V0; m0; h0; n0; a0; b0]
``````

We define a function that perturbs the parameters slightly:

``````function perturb_params(_p, δ::Float64)
neg_idxs = findall(_p .< 0)  # Find indices of negative elements
new_p = copy(_p)
new_p[neg_idxs] = - new_p[neg_idxs]  # Prevent the log taking negative arguments
perturbed_p = exp.(log.(new_p) + δ * randn(length(new_p)))
perturbed_p[neg_idxs] = - perturbed_p[neg_idxs]  # Restore the minus sign
return perturbed_p
end
``````

We then create a function which returns the old and new problems, and also integrates a cost:

``````function make_cost_problem(nomprob::ODEProblem, pertprob::ODEProblem)
# Make a new problem consisting of the previous two, plus cost integrand
dim1u, dim2u = [length(x) for x in (nomprob.u0, pertprob.u0)]
nom_p = nomprob.p

new_vf = function (u, p, t)
I_pin = (nomprob.f(u[1:dim1u], nom_p, t)[1] - pertprob.f(u[dim1u+1:dim1u+dim2u], p, t)[1])
I_pin_term = zeros(length(pertprob.u0))
I_pin_term[1] = I_pin

vcat(nomprob.f(u[1:dim1u], nom_p, t),
pertprob.f(u[dim1u+1:dim1u+dim2u] .+ I_pin_term, p, t),
(I_pin)^2)
end

new_p = pertprob.p
new_u0 = vcat(nomprob.u0, pertprob.u0, 0.)
cost_problem = ODEProblem(new_vf, new_u0, nomprob.tspan, new_p)
return cost_problem, (dim1u, dim2u)
end

function get_cost(prob1::ODEProblem, prob2::ODEProblem)
probnew = make_cost_problem(prob1, prob2)[1]
sol_at_endpoint = solve(probnew, save_everystep=false)[end][end]
return cost = sol_at_endpoint
end

function testfunc(p, costprob)
prob = remake(costprob; u0=convert.(eltype(p), costprob.u0), p=p)
return solve(prob, save_everystep=false)[end][end]
end

testpertprob = param_prob(p_, nomprob)
costprob = make_cost_problem(nomprob, testpertprob)[1]
return ForwardDiff.gradient(p -> testfunc(p, costprob), p_)
end

function cost_hess(_p,  nomprob)
ToDiff = params -> cost_grad(params, nomprob)
return ForwardDiff.jacobian(ToDiff, _p)
end
``````

This new function `make_cost_problem` provides problems which are solvable, but the function to take the gradient fails.

If anyone would be able to shed some light on this I’d be very grateful. Thanks!

See https://docs.juliadiffeq.org/latest/analysis/sensitivity/ . Just use `concrete_solve` and `sensealg=ForwardDiffAdjoint()`, and then Zygote.gradient will use ForwardDiff internally.

Otherwise, please post your error message since without the error message it’s much harder to debug.