Gradient-based parameter estimation for Catalyst.jl models

Hi, I am trying to estimate parameters and initial conditions in an ODE system built with Catalyst.jl. I first find the steady state using a callback, and then, in a separate simulation, I apply an input to obtain the simulation result that I actually fit against.

I have gotten this to work with the Adam optimizer, but I would like to use one of the gradient-based methods in Optim.jl (e.g. BFGS). However, with those I get the warnings (1) "First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN" and then (2) "Instability detected, aborting". I have tried several of Optim.jl's gradient-based algorithms and cannot find one that works for the way I have set this problem up.

An MWE that reproduces the problem is below:

using Catalyst, DifferentialEquations, Plots, Optimization, OptimizationOptimisers, OptimizationOptimJL, ForwardDiff

# Steady-state termination test (same pattern as DiffEqCallbacks' allDerivPass):
# return true once every component of du is within abstol or within reltol*|u|.
# min_t is accepted for signature compatibility but is unused here (it is nothing below).
function allDerivPass(integrator, abstol, reltol, min_t)
    if DiffEqBase.isinplace(integrator.sol.prob)
        testval = first(get_tmp_cache(integrator))
        DiffEqBase.get_du!(testval, integrator)
        if typeof(integrator.sol.prob) <: DiffEqBase.DiscreteProblem
            @. testval = testval - integrator.u
        end
    else
        testval = get_du(integrator)
        if typeof(integrator.sol.prob) <: DiffEqBase.DiscreteProblem
            testval = testval - integrator.u
        end
    end

    if typeof(integrator.u) <: Array
        any(abs(d) > abstol && abs(d) > reltol * abs(u)
            for (d, abstol, reltol, u) in zip(testval, Iterators.cycle(abstol),
            Iterators.cycle(reltol), integrator.u)) &&
            (return false)
    else
        any((abs.(testval) .> abstol) .& (abs.(testval) .> reltol .* abs.(integrator.u))) &&
            (return false)
    end

    return true
end

rn = @reaction_network Endothelin begin 
    @species u1(t) u2(t) Monomer(t) Polymer(t) End1(t) End2(t) x1(t) x4(t) x6(t)
    @parameters p5 p7 p8 p9 p10 p11 p16 p22 p23
    (p1, p2), u1 + x1 <--> x2
    p3, x2 --> u1 + x3
    p4, x3 --> x1
    (p5, p6), x3 + x4 <--> x5
    p7*End1, Monomer --> Polymer
    p8*End1, Polymer => Monomer
    p9*End2, Monomer --> Polymer
    p10*End2, Polymer => Monomer
    p11, 2*Monomer --> 2*Polymer + End1 + End2
    p12/(1 + p13*Polymer/End2), End1 + End2 --> ∅
    p14*Polymer, ∅ --> End1 + End2
    p15*exp(-((log10(x1)) - log10(p16)/p17)^2)*Polymer, ∅ --> End1 + End2
    (p18, p19), x6 + End1 <--> x7
    (p20, p21), u2 + End1 <--> u2_End1
    p22*u2_End1, Monomer --> Polymer
    p23, u2 --> ∅
end 

p_list = [:p1 => 10, :p2 => 1, :p3 => 2, :p4 => 0.01, :p5 => 10, :p6 => 10, :p7 => 10, :p8 => 1, :p9 => 1, :p10 => 10, :p11 => 0.0001, :p12 => 1, :p13 => 0.01, :p14 => 1e-6, :p15 => 1e-1, :p16 => 0.4, :p17 => 0.5, :p18 => 10, :p19 => 0.01, :p20 => 10, :p21 => 0.01, :p22 => 100, :p23 => 0.1]
reaction_params = parameters(rn)
p0 = zeros(length(reaction_params))
for i ∈ eachindex(reaction_params)
	for j ∈ eachindex(p_list)
		if string(p_list[j][1]) == string(reaction_params[i])
			p0[i] = p_list[j][2]
			break
		end
	end
end

u_list = [:u1 => 0, :u2 => 0, :u2_End1 => 0, :x1 => 10, :x2 => 0, :x3 => 0, :x4 => 30, :x5 => 0, :x6 => 1, :x7 => 0, :Monomer => 50, :Polymer => 50, :End1 => 2, :End2 => 2]
reaction_states = states(rn)
u0 = zeros(length(reaction_states))
for i ∈ eachindex(reaction_states)
	for j ∈ eachindex(u_list)
		if string(u_list[j][1]) == string(reaction_states[i])[1:end-3]
			u0[i] = u_list[j][2]
			break
		end
	end
end
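
As an aside (a sketch I have not verified against this exact model): depending on the Catalyst version, ODEProblem also accepts the symbol-keyed pair vectors directly, which would make the reordering loops above unnecessary. The name orig_prob_alt below is only illustrative.

# Sketch, assuming a Catalyst version that accepts symbol maps directly:
orig_prob_alt = ODEProblem(rn, u_list, (0.0, Inf), p_list)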

abstol = 1e-7
reltol = 1e-5
min_t = nothing
condition_SS = (u, t, integrator) -> allDerivPass(integrator, abstol, reltol, min_t)
affect_SS! = (integrator) -> terminate!(integrator)
cb_SS = DiscreteCallback(condition_SS, affect_SS!; save_positions = (true, false))
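
For reference, the same termination behavior should be obtainable from the ready-made callback in DiffEqCallbacks.jl; a sketch, assuming that package is installed:

# Sketch: equivalent steady-state termination callback from DiffEqCallbacks.jl
using DiffEqCallbacks
cb_SS_alt = TerminateSteadyState(abstol, reltol)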

orig_prob = ODEProblem(rn, u0, (0,Inf), p0)
sol_orig_prob = solve(orig_prob, Rodas5(), callback=cb_SS)

sample_times = 60 .* collect(0:10:100)
stim_prob = remake(orig_prob; u0 = vcat(10, sol_orig_prob.u[end][2:end]), tspan = (0, 60*100), p = p0)
sol_stim_prob = solve(stim_prob, Rodas5(), saveat=sample_times)

plot(sol_stim_prob, idxs=12)

data1 = 50 .* [1, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
data2 = 50 .* [1, 1.2, 1.3, 1.4, 1.4, 1.4, 1.4, 1.4, 1.3, 1.2, 1.1]
data = vcat(data1, data2)

function loss(x)

	alpha = [0.01, 100.0]

	#alg = RadauIIA5()
	alg = Rodas5()

	ss_prob = remake(orig_prob, u0=vcat(0.0, 0.0, 50.0, 50.0, 2.0, 2.0, exp.(x[1:3]), zeros(Float64, 5)), tspan=(0,Inf), p = vcat(exp.(x[4:end]), p0[9:end]))
	sol_ss_prob = solve(ss_prob, alg, callback=cb_SS)

	# u1
	inh_prob = remake(orig_prob; u0 = vcat(1.0, sol_ss_prob.u[end][2:end]), tspan = (0, 6000), p = vcat(exp.(x[4:end]), p0[9:end]))
	sol_inh_prob = solve(inh_prob, alg; saveat=sample_times)

	# u2
	stim_inh_prob = remake(orig_prob; u0 = vcat(1.0, 2.0, sol_ss_prob.u[end][3:end]), tspan=(0,6000), p = vcat(exp.(x[4:end]), p0[9:end]))
	sol_stim_inh_prob = solve(stim_inh_prob, alg; saveat=sample_times)
	
    y_model = vcat(sol_inh_prob[4,:], sol_stim_inh_prob[4,:])

	loss =  (1.0/(length(data)))*sum(abs2, (y_model .- data)) + 
			alpha[1]*(1.0/(length(x)))*(sum(abs, x .- log.(vcat(u0[7:9],p0[1:8])))) + 
			alpha[2]*(sum(abs, (sol_ss_prob.u[end][4] .- 50.0)))

	return loss, [sol_inh_prob, sol_stim_inh_prob]
end

callback_opt = function (state, l, sol)
	display(l)

	plt = plot(layout=(2,1), size=(1200,1200))
	plot!(plt, 0:10:100, sol[1][4,:], ylim = (35, 75),  xlim=(-5, 105), label="", subplot=1)
	scatter!(plt, 0:10:100, data1, label = "", subplot=1)
    plot!(plt, 0:10:100, sol[2][4,:], ylim = (35, 75),  xlim=(-5, 105), label="", subplot=2)
	scatter!(plt, 0:10:100, data2, label = "", subplot=2)
	
	display(plt)

	# Tell Optimization.solve not to halt the optimization. If this returns true,
	# the optimization stops.
	return false
end


adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((p,alpha) -> loss(p), adtype)

x00 = vcat(u0[7:9],p0[1:8])
optprob = Optimization.OptimizationProblem(optf, log.(x00), nothing)
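
As a sanity check (a minimal sketch using the already-imported ForwardDiff), the objective and its gradient can be evaluated once at the initial guess to see whether the NaNs already appear in the very first gradient evaluation:

# Evaluate the loss and its ForwardDiff gradient at the initial guess
l0, _ = loss(log.(x00))
g0 = ForwardDiff.gradient(x -> first(loss(x)), log.(x00))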
	
opt1_sol = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01), callback=callback_opt, maxiters=10)
opt2_sol = Optimization.solve(optprob, Optim.BFGS(), callback=callback_opt, maxiters=10)

  1. At what parameters exactly does it exit with NaNs?
  2. With what tolerances were you running the solvers, and did you try lowering them?
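
For reference, tighter tolerances can be requested per solve call; a sketch with purely illustrative values:

# Illustrative only: pass abstol/reltol to each solve call
sol_orig_prob = solve(orig_prob, Rodas5(); abstol = 1e-10, reltol = 1e-10, callback = cb_SS)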