I am interested in optimizing the coefficients of an ODE (such as Lotka-Volterra) under some cost function, given that the ODE is influenced by a set of controls c1(t), …, cn(t) for which I observe simulated trajectories. As a simple example, I have written an MRE below with one Lotka-Volterra simulation in which the second coordinate is subject to a time-varying control. I want to take the gradient of the squared cost of this simulation with respect to the Lotka-Volterra parameters, and to see how this adjoint-based gradient computation is affected by the control input, which is an interpolated function.
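Concretely, the controlled system and the cost in the MRE below are

$$
\dot u_1 = (\alpha - \beta u_2)\,u_1, \qquad
\dot u_2 = (\delta u_1 - \gamma)\,u_2 + c(t), \qquad
J(p) = \sum_i \lVert u(t_i; p)\rVert^2,
$$

where $c(t)$ is the interpolated control and the sum runs over the save points `tsteps`.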
cd(@__DIR__)
using Pkg
Pkg.activate(".")
Pkg.instantiate()
using Flux
using Zygote
using ForwardDiff
using DifferentialEquations
using SciMLSensitivity
using Optimization
using Random
using MAT
using Interpolations
using Profile
using FlameGraphs
# Define Controlled ODE
struct ControlledODE{Of, So, Se, T, Sa}
    odefunc::Of
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end
function ControlledODE(odefunc;
                       solver=Tsit5(),
                       sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
                       tspan=(0.0f0, 1.0f0),
                       saveat=[])
    return ControlledODE(odefunc, solver, sensealg, tspan, saveat)
end
function (c::ControlledODE)(u0, control; tspan=c.tspan, saveat=c.saveat)
    function ode(du, u, p, t)
        c.odefunc(du, u, t)
        du[end] += control(t)   # add the control to the last coordinate
    end
    prob = ODEProblem(ode, u0, tspan)
    return solve(prob, c.solver; sensealg=c.sensealg, saveat=saveat)
end
Flux.@functor ControlledODE
# Lotka Volterra with control input
tspan = (0.0, 10.0)
Nt = 100
tsteps = LinRange(tspan[1], tspan[2], Nt)
v = sin.(pi.*tsteps)
control = linear_interpolation(tsteps, v)
p = Float64[2.2, 1.0, 2.0, 0.4]
function lotka_volterra(du, u, t)
    x, y = u
    α, β, δ, γ = p   # note: p is the global parameter vector defined above
    du[1] = (α - β*y)*x
    du[2] = (δ*x - γ)*y
end
u0 = Float64[1.0, 1.0]
sensealg = SciMLSensitivity.InterpolatingAdjoint(; autojacvec=ZygoteVJP())
model = ControlledODE(lotka_volterra; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=tsteps)
# Cost for simulation
function cost(model)
    return sum(abs2, model(u0, control))
end
# Compute the gradient
Flux.gradient(m -> cost(m), model)
I am currently stuck on a bug that makes the gradient computation fail: Flux cannot see how the parameter vector p affects the model simulation, since p is only referenced as a global inside lotka_volterra and is never passed to the ODEProblem or stored in the model.
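For reference, the quantity I am after is dJ/dp for the four coefficients. The sketch below is only there to pin down that target: it bypasses ControlledODE entirely and uses a standalone parameterised right-hand side (lv_controlled! and cost_p are hypothetical names I introduce here) differentiated with ForwardDiff, which is already in my imports, reusing u0, tspan, tsteps, control, and p from the MRE above.

# Hypothetical sanity check, not my actual code path: a standalone
# parameterised version of the controlled system, differentiated with
# ForwardDiff instead of the adjoint, just to clarify which gradient I mean.
function lv_controlled!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = (α - β*y)*x
    du[2] = (δ*x - γ)*y + control(t)   # same interpolated control as above
end

function cost_p(p)
    # promote u0 to the dual element type so ForwardDiff can flow through solve
    prob = ODEProblem(lv_controlled!, eltype(p).(u0), tspan, p)
    sol = solve(prob, Tsit5(); saveat=tsteps)
    return sum(abs2, Array(sol))
end

ForwardDiff.gradient(cost_p, p)   # ∂J/∂p, the gradient I want from the adjoint path

What I cannot figure out is how to get the same gradient through the ControlledODE/Flux/adjoint setup above.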
Any help would be much appreciated!