Hi, I am trying to estimate parameters/initial conditions in an ODE system built with Catalyst.jl. I first find the steady-state using a callback, and then in a separate simulation I add some input to get my actual desired simulation result.
I have gotten this to work with the Adam optimizer, but I would like to use one of the gradient-based algorithms in Optim.jl (e.g. BFGS). However, when I do, I get two warnings: (1) "First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN" and then (2) "Instability detected, aborting". I have tried other gradient-based algorithms and cannot find one that works for the problem as I have set it up.
An MWE that reproduces the problem is below:
using Catalyst, DifferentialEquations, Plots, Optimization, OptimizationOptimisers, OptimizationOptimJL, ForwardDiff
"""
    allDerivPass(integrator, abstol, reltol, min_t)

Return `true` when `integrator` is considered to be at steady state, i.e. every
component `d` of the current derivative satisfies `abs(d) <= abstol` or
`abs(d) <= reltol * abs(u)` for the matching state component `u`.

For a `DiscreteProblem` the tested quantity is the update minus the current state.

# Arguments
- `integrator`: the integrator handed to a callback condition.
- `abstol`, `reltol`: scalar or per-component tolerances (cycled over the states when
  `integrator.u` is an `Array`).
- `min_t`: `nothing`, or the earliest time at which steady state may be reported.

# Returns
`false` if any derivative component is non-finite or outside both tolerances, or if
`min_t` is given and `integrator.t < min_t`; `true` otherwise.
"""
function allDerivPass(integrator, abstol, reltol, min_t)
    prob = integrator.sol.prob
    if DiffEqBase.isinplace(prob)
        # Reuse the integrator's first temporary cache as scratch space for du.
        testval = first(get_tmp_cache(integrator))
        DiffEqBase.get_du!(testval, integrator)
        if prob isa DiffEqBase.DiscreteProblem
            @. testval = testval - integrator.u
        end
    else
        testval = get_du(integrator)
        if prob isa DiffEqBase.DiscreteProblem
            testval = testval - integrator.u
        end
    end

    # Every comparison against NaN is false, so without this guard a NaN derivative
    # would slip through the tolerance test below and be reported as "steady state".
    all(isfinite, testval) || return false

    if integrator.u isa Array
        any(abs(d) > atol_i && abs(d) > rtol_i * abs(u)
            for (d, atol_i, rtol_i, u) in zip(testval, Iterators.cycle(abstol),
                Iterators.cycle(reltol), integrator.u)) && return false
    else
        any((abs.(testval) .> abstol) .&
            (abs.(testval) .> reltol .* abs.(integrator.u))) && return false
    end

    # Honour the optional minimum time instead of silently ignoring the argument.
    return min_t === nothing || integrator.t >= min_t
end
# Reaction network named `Endothelin`, written in the `@reaction_network` DSL.
# Each line is `rate, substrates arrow products`; a `(forward, reverse)` rate tuple with
# `<-->` declares a reversible pair.
# NOTE(review): the index-based code elsewhere in this script (`u0[7:9]`, `p0[1:8]`,
# `p0[9:end]`, `sol[4, :]`) appears to rely on the `@species` / `@parameters`
# declarations below fixing the leading order of `states(rn)` / `parameters(rn)` —
# confirm before reordering or adding entries.
rn = @reaction_network Endothelin begin
@species u1(t) u2(t) Monomer(t) Polymer(t) End1(t) End2(t) x1(t) x4(t) x6(t)
@parameters p5 p7 p8 p9 p10 p11 p16 p22 p23
(p1, p2), u1 + x1 <--> x2
p3, x2 --> u1 + x3
p4, x3 --> x1
(p5, p6), x3 + x4 <--> x5
p7*End1, Monomer --> Polymer
# NOTE(review): this reverse reaction (and the `p10` one below) uses `=>` while the
# forward ones use `-->`. In Catalyst `=>` is documented to use the rate expression
# as written, without multiplying by the substrate — confirm that is intended here.
p8*End1, Polymer => Monomer
p9*End2, Monomer --> Polymer
p10*End2, Polymer => Monomer
p11, 2*Monomer --> 2*Polymer + End1 + End2
p12/(1 + p13*Polymer/End2), End1 + End2 --> ∅
p14*Polymer, ∅ --> End1 + End2
# NOTE(review): as parenthesised, only `log10(p16)` is divided by `p17`. If a bell
# curve in log10(x1) centred at p16 with width p17 is meant, the exponent would be
# `((log10(x1) - log10(p16))/p17)^2` — confirm. Also, `log10(x1)` is undefined for
# x1 <= 0, so this rate is a candidate source of NaNs if x1 reaches zero or dips
# below it during a solve.
p15*exp(-((log10(x1)) - log10(p16)/p17)^2)*Polymer, ∅ --> End1 + End2
(p18, p19), x6 + End1 <--> x7
(p20, p21), u2 + End1 <--> u2_End1
p22*u2_End1, Monomer --> Polymer
p23, u2 --> ∅
end
# Nominal parameter values keyed by parameter name.
p_list = [
    :p1 => 10, :p2 => 1, :p3 => 2, :p4 => 0.01, :p5 => 10, :p6 => 10,
    :p7 => 10, :p8 => 1, :p9 => 1, :p10 => 10, :p11 => 0.0001, :p12 => 1,
    :p13 => 0.01, :p14 => 1e-6, :p15 => 1e-1, :p16 => 0.4, :p17 => 0.5,
    :p18 => 10, :p19 => 0.01, :p20 => 10, :p21 => 0.01, :p22 => 100, :p23 => 0.1,
]
reaction_params = parameters(rn)

# Arrange the values in the order reported by `parameters(rn)`. The first entry of
# `p_list` whose name matches is used; parameters without an entry keep the value 0.
p0 = zeros(length(reaction_params))
for (slot, param) in enumerate(reaction_params)
    wanted = string(param)
    hit = findfirst(entry -> string(first(entry)) == wanted, p_list)
    isnothing(hit) || (p0[slot] = last(p_list[hit]))
end
# Initial values keyed by species name.
u_list = [
    :u1 => 0, :u2 => 0, :u2_End1 => 0, :x1 => 10, :x2 => 0, :x3 => 0, :x4 => 30,
    :x5 => 0, :x6 => 1, :x7 => 0, :Monomer => 50, :Polymer => 50, :End1 => 2,
    :End2 => 2,
]
reaction_states = states(rn)

# Arrange the values in the order reported by `states(rn)`. The first entry of `u_list`
# whose name matches is used; species without an entry keep the value 0.
u0 = zeros(length(reaction_states))
for (slot, state) in enumerate(reaction_states)
    # The printed state carries a three-character suffix (assumed to be "(t)", as in the
    # `@species` declarations). Remove it by character count with `chop` rather than by
    # byte index: `str[1:end-3]` raises a StringIndexError when the species name ends
    # in a multi-byte (non-ASCII) character.
    wanted = chop(string(state); tail = 3)
    hit = findfirst(entry -> string(first(entry)) == wanted, u_list)
    isnothing(hit) || (u0[slot] = last(u_list[hit]))
end
# Tolerances and optional minimum time forwarded to `allDerivPass`.
abstol = 1e-7
reltol = 1e-5
min_t = nothing

# Condition: evaluated by the discrete callback; true once `allDerivPass` reports that
# the derivatives are within tolerance.
condition_SS = function (u, t, integrator)
    return allDerivPass(integrator, abstol, reltol, min_t)
end

# Effect: stop the integration.
affect_SS! = function (integrator)
    return terminate!(integrator)
end

cb_SS = DiscreteCallback(condition_SS, affect_SS!; save_positions = (true, false))
# Template problem: integrate on an open-ended time span and let `cb_SS` terminate the
# run.
orig_prob = ODEProblem(rn, u0, (0, Inf), p0)
sol_orig_prob = solve(orig_prob, Rodas5(); callback = cb_SS)

# Output grid: 0, 600, ..., 6000.
sample_times = collect(0:600:6000)

# Restart from the final saved state of the first run with the first state set to 10.
stim_prob = remake(
    orig_prob;
    u0 = [10; last(sol_orig_prob.u)[2:end]],
    tspan = (0, 6000),
    p = p0,
)
sol_stim_prob = solve(stim_prob, Rodas5(); saveat = sample_times)
plot(sol_stim_prob; idxs = 12)
# Target values for the two fitted simulations: relative levels scaled by 50, one value
# per entry of the 11-point sampling grid.
data1 = map(r -> 50 * r, [1, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8])
data2 = map(r -> 50 * r, [1, 1.2, 1.3, 1.4, 1.4, 1.4, 1.4, 1.4, 1.3, 1.2, 1.1])

# Stacked in the same order in which the model output is stacked in `loss`.
data = [data1; data2]
"""
    loss(x)

Evaluate the fitting objective for the log-scaled decision vector `x`.

`exp.(x[1:3])` are used as the initial values of states 7–9 and `exp.(x[4:end])` as the
leading parameters; the remaining parameters come from the global `p0[9:end]`. Three
simulations are derived from the global `orig_prob`:

1. a run on `(0, Inf)` that the steady-state callback `cb_SS` terminates;
2. a run on `(0, 6000)` from that final state with state 1 set to `1.0`;
3. a run on `(0, 6000)` from that final state with state 1 set to `1.0` and state 2 set
   to `2.0`.

The objective is the mean squared error between state 4 of runs 2 and 3 (saved at
`sample_times`) and `data`, plus `alpha[1]` times the mean absolute deviation of `x`
from the log of the nominal values, plus `alpha[2]` times the absolute deviation of the
steady-state value of state 4 from `50.0`.

# Returns
A tuple `(value, [sol_inh_prob, sol_stim_inh_prob])`. When any of the three solves ends
with an unsuccessful return code, `value` is `Inf` (in the element type of `x`), so an
optimiser can reject the trial point instead of receiving NaNs or hitting a length
mismatch against `data`.
"""
function loss(x)
    alpha = (0.01, 100.0)
    # Stiff solver used for all three runs (RadauIIA5() was the alternative tried).
    alg = Rodas5()

    # Built once and shared by the three problems.
    p_trial = vcat(exp.(x[4:end]), p0[9:end])

    ss_prob = remake(
        orig_prob;
        u0 = vcat(0.0, 0.0, 50.0, 50.0, 2.0, 2.0, exp.(x[1:3]), zeros(Float64, 5)),
        tspan = (0, Inf),
        p = p_trial,
    )
    sol_ss_prob = solve(ss_prob, alg; callback = cb_SS)
    u_ss = sol_ss_prob.u[end]

    # u1
    inh_prob = remake(orig_prob; u0 = vcat(1.0, u_ss[2:end]), tspan = (0, 6000),
        p = p_trial)
    sol_inh_prob = solve(inh_prob, alg; saveat = sample_times)

    # u2
    stim_inh_prob = remake(orig_prob; u0 = vcat(1.0, 2.0, u_ss[3:end]),
        tspan = (0, 6000), p = p_trial)
    sol_stim_inh_prob = solve(stim_inh_prob, alg; saveat = sample_times)

    trajectories = [sol_inh_prob, sol_stim_inh_prob]

    # A solve that aborted (instability, NaN in the first function call, iteration
    # limit, ...) saves fewer points than `data` has, so the residual below could not
    # even be formed. Report an infinite loss for such trial points.
    solves = (sol_ss_prob, sol_inh_prob, sol_stim_inh_prob)
    if !all(SciMLBase.successful_retcode, solves)
        return convert(eltype(x), Inf), trajectories
    end

    y_model = vcat(sol_inh_prob[4, :], sol_stim_inh_prob[4, :])
    x_nominal = log.(vcat(u0[7:9], p0[1:8]))

    misfit = sum(abs2, y_model .- data) / length(data)
    prior_penalty = alpha[1] * sum(abs, x .- x_nominal) / length(x)
    steady_penalty = alpha[2] * abs(u_ss[4] - 50.0)

    return misfit + prior_penalty + steady_penalty, trajectories
end
# Optimisation callback: show the current loss and plot state 4 of both fitted
# simulations (top: `sol[1]` vs `data1`, bottom: `sol[2]` vs `data2`).
#
# Arguments:
# - `state`: optimiser state (unused).
# - `l`: current loss value.
# - `sol`: the two solutions returned by `loss` as its second output.
#
# Returns `false` so that `Optimization.solve` keeps iterating (returning `true` would
# stop the optimisation).
callback_opt = function (state, l, sol)
    display(l)
    plt = plot(layout = (2, 1), size = (1200, 1200))
    for (panel, observed) in enumerate((data1, data2))
        traj = sol[panel]
        # Use the trajectory's own saved times (divided by 60, which maps the
        # `sample_times` grid back onto 0:10:100) instead of a hard-coded axis, so a
        # solve that stopped early can still be plotted without a length mismatch.
        plot!(plt, traj.t ./ 60, traj[4, :]; ylim = (35, 75), xlim = (-5, 105),
            label = "", subplot = panel)
        scatter!(plt, 0:10:100, observed; label = "", subplot = panel)
    end
    display(plt)
    return false
end
# Differentiate the objective with forward-mode AD.
adtype = Optimization.AutoForwardDiff()

# The optimiser calls the objective as f(decision_vector, extra); `loss` only needs the
# first argument.
optf = Optimization.OptimizationFunction(adtype) do p, alpha
    loss(p)
end

# Start from the nominal values of the fitted initial conditions and parameters,
# expressed on the log scale that `loss` expects.
x00 = [u0[7:9]; p0[1:8]]
optprob = Optimization.OptimizationProblem(optf, map(log, x00), nothing)

opt1_sol = Optimization.solve(
    optprob,
    OptimizationOptimisers.Adam(0.01);
    callback = callback_opt,
    maxiters = 10,
)
opt2_sol = Optimization.solve(
    optprob,
    Optim.BFGS();
    callback = callback_opt,
    maxiters = 10,
)