TL;DR: Is it possible to use both state and observed variables from an ODESolution in a likelihood without concatenating them ex post or using mutation?
Maybe there is something dead-simple I’m missing!
I’m starting to convert some models in continuous time into the Julia SciML ecosystem (which is fantastic). For my application it is often convenient to describe the system’s behavioral equations using both differential and algebraic equations, with estimated model parameters contained in both, plus some identities/constraints. So the system is a DAE (medium/large), but the likelihood relies on both “state” and “observed” variables in the ModelingToolkit.jl terminology.
This runs me into the standard mutation problem with Zygote. My current approach is to tear/simplify the system, then concatenate the state and observed variables after the integration step to compute the likelihood. This seems like a poor idea, since it involves mutation and forces me to rely on ForwardDiff alone (undesirable when the number of parameters P is large). I suppose I could instead take total time derivatives of all the observed algebraic equations and use only the expanded set of state variables in the likelihood (with reverse-mode AD). Is there a simple way to avoid this and access both the state and observed variables from the ODESolution in a way that is compatible with reverse-mode AD?
My apologies if this simply relitigates the well-known issue of mutation with Zygote; that is not my intent (though I have tried, unsuccessfully, to avoid it using Buffer). I hoped this was a use case where an easier solution is available and that I am simply misunderstanding some basic package functionality regarding differentiability.
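To make the "total time derivative" workaround concrete, here is what it would look like for the toy system given further down, where the only algebraic equation is z = θ*u₃. This is only a sketch of the idea, not something I have implemented:

```latex
\begin{aligned}
z &= \theta\, u_3 \\
\frac{dz}{dt} &= \theta\, \frac{du_3}{dt} = \theta \left( u_1 u_2 - \beta z \right), \qquad z(0) = \theta\, u_3(0)
\end{aligned}
```

The system then has four differential states (u₁, u₂, u₃, z) and no observed variables, at the cost of an initial condition for z that depends on the estimated parameter θ.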
Here is a simple system illustrating the point:
using ModelingToolkit, DifferentialEquations
using Distributions, Optimization
# define symbolic system
@parameters α, ρ, β, θ
@variables t, u₁(t), u₂(t), u₃(t), z(t)
D = Differential(t)
eqs = [D(u₁) ~ α*(u₂-u₁),
D(u₂) ~ u₁ * (ρ - u₃) - u₂,
D(u₃) ~ u₁ * u₂ - β*z,
z ~ θ*u₃] # <-- also have data for this, absence from likelihood generates weak identification
@named lorenz1 = ODESystem(eqs)
# weak identification: β vs. θ
simple_sys = structural_simplify(lorenz1)
p = [10, 28 , 1/6, 2]
tspan = (0.0,3.0)
tseq = collect(tspan[1]:0.1:tspan[2])
u0 = [1.0,0.0,0.0]
prob = ODEProblem(simple_sys,u0,tspan,p;)
sol = solve(prob,Tsit5(),saveat=tseq)
# Generate data including z
# concatenate obs + state
obs_symb = states(lorenz1)
traj = sol[obs_symb]
N, T = (length(obs_symb),length(tseq))
data_cat = [traj[t][i] + 2.0 * randn() for i in 1:N, t in 1:T]
p0 = vcat(prob.p,fill(2.2,N)) # init param
function loglik(p)
# re index param vector
θ = eltype(p).(p[1:length(prob.p)])
sigma = eltype(p).(p[(length(prob.p)+1):(N+length(prob.p))])
# solve simp/torn problem
tmp_prob = remake(prob,u0=eltype(θ).(prob.u0),p=θ)
tmp_sol = solve(tmp_prob, Tsit5(), saveat=0.1)
# likelihood
ll = Vector{Real}(undef,T) # preallocated, then filled in place below (array mutation)
for t ∈ 1:T
ll[t] = logpdf(MvNormal(tmp_sol[obs_symb][t],sigma), data_cat[:,t]) # use the re-solved trajectory; allocating; inefficient
end
return -sum(ll) # negative log-likelihood
end
# test AD
using ForwardDiff, Zygote
using SciMLSensitivity
loglik(p0)
fd_grad = ForwardDiff.gradient(loglik,p0) # works fine
zy_grad = Zygote.gradient(loglik,p0) # fails: Zygote does not support array mutation
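For concreteness, this is the mutation-free shape I would like the likelihood to have. It is a minimal sketch in plain Julia, with no SciML packages: `predict` is a hypothetical stand-in for "solve the problem, then stack state and observed variables into an N×T matrix", and `gauss_logpdf` and `negloglik` are illustrative names, not package functions. It shows only the accumulation pattern (summing over a generator instead of writing into a preallocated vector). It does not address how to pull the observed variables out of the ODESolution in an AD-friendly way, which is the part I am asking about.

```julia
# Hypothetical stand-in for the solver output: an N×T matrix of predictions,
# built without writing into a preallocated array.
predict(θ, ts) = reduce(hcat, [[θ[1] * s, θ[2] * s^2] for s in ts])

# Log-density of independent Gaussians with means μ and standard deviations σ.
gauss_logpdf(y, μ, σ) = sum(@. -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π))

function negloglik(p, data, ts)
    θ, σ = p[1:2], p[3:4]          # model parameters, then noise scales
    pred = predict(θ, ts)
    # Sum over a generator: no vector is allocated and filled in place.
    return -sum(gauss_logpdf(data[:, k], pred[:, k], σ) for k in axes(data, 2))
end

ts = 0.0:0.5:2.0
data = predict([1.0, 2.0], ts)     # noise-free synthetic data for the check
nll = negloglik([1.0, 2.0, 1.0, 1.0], data, ts)
```

With noise-free data and unit standard deviations, every residual is zero, so `nll` reduces to the Gaussian normalizing constants alone.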