Hi everyone,
The basic issue is: differentiating (using ForwardDiff.jl) an ODEProblem from DifferentialEquations.jl with respect to parameters, where the ODEProblem is built from another ODEProblem.
My problem arises in this code:
function testfunc(p, costprob)
    prob = remake(costprob; u0=convert.(eltype(p), costprob.u0), p=p)
    return solve(prob, save_everystep=false)[end][end]
end

function cost_grad(p_, nomprob)
    testpertprob = param_prob(p_, nomprob)
    costprob = make_cost_problem(nomprob, testpertprob)[1]
    return ForwardDiff.gradient(p -> testfunc(p, costprob), p_)
end
The full code is below. I apologise for it being a lot, but I couldn’t figure out how to prune it since I don’t know where the issue lies.
using ParameterizedFunctions
using DiffEqBiological
using DifferentialEquations
using DiffEqSensitivity
using Random
using ForwardDiff
using Optim
using Flux
using DiffEqFlux
using LinearAlgebra
using Statistics
using Plots
using LaTeXStrings
We take the Connor-Stevens model (a model of neuronal dynamics) as our original problem. This is where we create it:
m_shift = -5.3
h_shift = -12
n_shift = -4.3
# Gating equations
# Hodgkin-Huxley with shifts in temperature factor
α_m(V) = -.1 * (V + 35 + m_shift) / (exp(-(V + 35 + m_shift) / 10) - 1)
β_m(V) = 4 * exp(-(V + 60 + m_shift) / 18)
m_∞(V) = α_m(V) / (α_m(V) + β_m(V))
τ_m(V) = 1 / (3.8 * (α_m(V) + β_m(V)))
α_h(V) = .07 * exp(-(V + 60 + h_shift) / 20)
β_h(V) = 1 / (1 + exp(-(V + 30 + h_shift) / 10))
h_∞(V) = α_h(V) / (α_h(V) + β_h(V))
τ_h(V) = 1 / (3.8 * (α_h(V) + β_h(V)))
α_n(V) = -.01 * (V + 50 + n_shift) / (exp(-(V + 50 + n_shift) / 10) - 1)
β_n(V) = .125 * exp(-(V + 60 + n_shift) / 80)
n_∞(V) = α_n(V) / (α_n(V) + β_n(V))
# τ_n is doubled
τ_n(V) = 2 / (3.8 * (α_n(V) + β_n(V)))
# The A current
a_∞(V) = (.0761 * exp((V + 94.22) / 31.84) / (1 + exp((V + 1.17) / 28.93)))^(.3333)
τ_a(V) = .3632 + 1.158 / (1 + exp((V + 55.96) / 20.12))
b_∞(V) = 1 / (1 + exp((V + 53.3) / 14.54))^4
τ_b(V) = 1.24 + 2.678 / (1 + exp((V + 50) / 16.027))
# The ion currents
I_Na(m, h, V) = g_Na * m^3 * h * (V - E_Na)
I_K(n, V) = g_K * n^4 * (V - E_K)
I_A(a, b, V) = g_A * b * a^3 * (V - E_A)
I_l(V) = g_l * (V - E_l)
# The applied current
function I_app(t)
    if 50 ≤ t ≤ 200
        return 15.
    else
        return 0.
    end
end
vf_lean = @ode_def begin
    dV = -(g_Na * m^3 * h * (V - E_Na) + g_K * n^4 * (V - E_K) + g_A * b * a^3 * (V - E_A)
           + g_l * (V - E_l)) + I_app(t)
    dm = (m_∞(V) - m) / τ_m(V)
    dh = (h_∞(V) - h) / τ_h(V)
    dn = (n_∞(V) - n) / τ_n(V)
    da = (a_∞(V) - a) / τ_a(V)
    db = (b_∞(V) - b) / τ_b(V)
end g_Na g_K g_A g_l E_Na E_K E_A E_l
p = [120., 20., 47.7, 0.3, 55., -72., -75., -17.]
# Initial conditions
V0 = -74.
m0 = 0.004
h0 = 0.986
n0 = 0.1
a0 = 0.509
b0 = 0.426
u0 = [V0; m0; h0; n0; a0; b0]
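The nominal problem is then built in the usual way (the tspan below is illustrative, chosen to cover the applied-current pulse):

# Construct the nominal problem; tspan is an example value
tspan = (0., 250.)
nomprob = ODEProblem(vf_lean, u0, tspan, p)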
We define a function that perturbs the parameters slightly:
function perturb_params(_p, δ::Float64)
    neg_idxs = findall(_p .< 0)  # Find indices of negative elements
    new_p = copy(_p)
    new_p[neg_idxs] = -new_p[neg_idxs]  # Prevent the log taking negative arguments
    perturbed_p = exp.(log.(new_p) + δ * randn(length(new_p)))
    perturbed_p[neg_idxs] = -perturbed_p[neg_idxs]  # Restore the minus sign
    return perturbed_p
end
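For example (the seed and δ = 0.05 are just illustrative values):

# Illustrative usage: a small multiplicative perturbation of p
Random.seed!(1234)
p_pert = perturb_params(p, 0.05)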
We then create a function which combines the two problems into a single one that also integrates a cost:
function make_cost_problem(nomprob::ODEProblem, pertprob::ODEProblem)
    # Make a new problem consisting of the previous two, plus cost integrand
    dim1u, dim2u = [length(x) for x in (nomprob.u0, pertprob.u0)]
    nom_p = nomprob.p
    new_vf = function (u, p, t)
        I_pin = nomprob.f(u[1:dim1u], nom_p, t)[1] - pertprob.f(u[dim1u+1:dim1u+dim2u], p, t)[1]
        I_pin_term = zeros(length(pertprob.u0))
        I_pin_term[1] = I_pin
        vcat(nomprob.f(u[1:dim1u], nom_p, t),
             pertprob.f(u[dim1u+1:dim1u+dim2u] .+ I_pin_term, p, t),
             I_pin^2)
    end
    new_p = pertprob.p
    new_u0 = vcat(nomprob.u0, pertprob.u0, 0.)
    cost_problem = ODEProblem(new_vf, new_u0, nomprob.tspan, new_p)
    return cost_problem, (dim1u, dim2u)
end
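One thing I noticed while writing this up: new_vf allocates I_pin_term = zeros(length(pertprob.u0)), which is a plain Float64 array, so writing a ForwardDiff dual number into it would fail. A type-generic sketch of the closure (untested; I'm not certain this is the actual failure):

# Sketch: allocate I_pin_term with the element type of I_pin so that
# ForwardDiff.Dual entries can be stored in it (untested guess at a fix)
new_vf = function (u, p, t)
    I_pin = nomprob.f(u[1:dim1u], nom_p, t)[1] -
            pertprob.f(u[dim1u+1:dim1u+dim2u], p, t)[1]
    I_pin_term = zeros(typeof(I_pin), dim2u)  # was zeros(length(pertprob.u0))
    I_pin_term[1] = I_pin
    vcat(nomprob.f(u[1:dim1u], nom_p, t),
         pertprob.f(u[dim1u+1:dim1u+dim2u] .+ I_pin_term, p, t),
         I_pin^2)
end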
function get_cost(prob1::ODEProblem, prob2::ODEProblem)
    probnew = make_cost_problem(prob1, prob2)[1]
    return solve(probnew, save_everystep=false)[end][end]
end
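Putting these together, the cost of a perturbation can be evaluated like this (pertprob is just the nominal problem remade with perturbed parameters):

# Illustrative usage of get_cost on a perturbed copy of the problem
pertprob = remake(nomprob; p=perturb_params(p, 0.05))
cost = get_cost(nomprob, pertprob)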
function testfunc(p, costprob)
    prob = remake(costprob; u0=convert.(eltype(p), costprob.u0), p=p)
    return solve(prob, save_everystep=false)[end][end]
end

function cost_grad(p_, nomprob)
    testpertprob = param_prob(p_, nomprob)
    costprob = make_cost_problem(nomprob, testpertprob)[1]
    return ForwardDiff.gradient(p -> testfunc(p, costprob), p_)
end
function cost_hess(_p, nomprob)
    ToDiff = params -> cost_grad(params, nomprob)
    return ForwardDiff.jacobian(ToDiff, _p)
end
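Note: param_prob is referenced above but its definition didn't make it into the listing; a minimal stand-in consistent with how it's used would be:

# Hypothetical stand-in for param_prob: rebuild the nominal
# problem with a new parameter vector
param_prob(p_new, baseprob::ODEProblem) = remake(baseprob; p=p_new)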
This new function make_cost_problem produces problems that solve without issue, but cost_grad fails when ForwardDiff tries to differentiate through the solve.
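For concreteness, a call of the kind that fails (the perturbed parameter vector is illustrative):

g = cost_grad(perturb_params(p, 0.05), nomprob)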
If anyone would be able to shed some light on this I’d be very grateful. Thanks!