Effect of Interpolation on ODE Gradient Computation

I am interested in understanding how to optimize the coefficients of an ODE (such as Lotka-Volterra) under some cost function, given that the ODE is influenced by a set of controls c1(t), …, cn(t) for which I observe simulated trajectories. As a simple example, here is an MRE with one simulation of Lotka-Volterra in which the second coordinate is subjected to a time-varying control. I want to take the gradient of the squared cost of this simulation with respect to the Lotka-Volterra parameters and see how this adjoint-based gradient computation is affected by the control input, which is an interpolated function.

cd(@__DIR__)

using Pkg
Pkg.activate(".")
Pkg.instantiate()

using Flux
using Zygote
using ForwardDiff
using DifferentialEquations
using SciMLSensitivity
using Optimization
using Random
using MAT
using Interpolations
using Profile
using FlameGraphs

# Define Controlled ODE
struct ControlledODE{Of, So, Se, T, Sa}
    odefunc::Of
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end

function ControlledODE(odefunc;
    solver=Tsit5(),
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0),
    saveat=[])
    return ControlledODE(odefunc, solver, sensealg, tspan, saveat)
end

function (c::ControlledODE)(u0, control; tspan=c.tspan, saveat=c.saveat)
    function ode(du, u, p, t)
        c.odefunc(du, u, t)
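        # add the interpolated control input to the last state equation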
        du[length(du)] += control(t)
    end
    prob = ODEProblem(ode, u0, tspan)
    return solve(prob, c.solver; sensealg=c.sensealg, saveat=saveat)
end

Flux.@functor ControlledODE  # register the struct with Flux/Functors so its fields can be traversed like a layer's


# Lotka Volterra with control input
tspan = (0.0, 10.0)
Nt = 100
tsteps = LinRange(tspan[1], tspan[2], Nt)

v = sin.(pi.*tsteps)                       # control samples on the time grid
control = linear_interpolation(tsteps, v)  # callable control(t) for arbitrary t in tspan

p = Float64[2.2, 1.0, 2.0, 0.4]
function lotka_volterra(du, u, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = (α - β*y)*x
    du[2] = (δ*x - γ)*y
end

u0 = Float64[1.0, 1.0]

sensealg = SciMLSensitivity.InterpolatingAdjoint(; autojacvec=ZygoteVJP())  # adjoint sensitivity with Zygote VJPs
model = ControlledODE(lotka_volterra; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=tsteps)

# Cost for simulation
function cost(model)
    return sum(abs2, model(u0, control))
end

# Compute the gradient
Flux.gradient(m -> cost(m), model)

I am currently dealing with a bug in my code which makes the gradient computation fail: Flux cannot see how the parameter vector p has any effect on the model simulation.

Any help would be much appreciated!

Don’t do anything fancy. Just make it a function of p.

cd(@__DIR__)

using Pkg
Pkg.activate(".")
Pkg.instantiate()

using Flux
using Zygote
using ForwardDiff
using DifferentialEquations
using SciMLSensitivity
using Optimization
using Random
using MAT
using Interpolations
using Profile
using FlameGraphs

# Define Controlled ODE
struct ControlledODE{Of, So, Se, T, Sa}
    odefunc::Of
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end

function ControlledODE(odefunc;
    solver=Tsit5(),
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0),
    saveat=[])
    return ControlledODE(odefunc, solver, sensealg, tspan, saveat)
end

function (c::ControlledODE)(u0, control, p; tspan=c.tspan, saveat=c.saveat)
    function ode(du, u, _p, t)
        c.odefunc(du, u, _p, t)
        du[length(du)] += control(t)
    end
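    # pass p to the ODEProblem so the adjoint can compute sensitivities with respect to it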
    prob = ODEProblem(ode, u0, tspan, p)
    return solve(prob, c.solver; sensealg=c.sensealg, saveat=saveat)
end

Flux.@functor ControlledODE

# Lotka Volterra with control input
tspan = (0.0, 10.0)
Nt = 100
tsteps = LinRange(tspan[1], tspan[2], Nt)

v = sin.(pi.*tsteps)
control = linear_interpolation(tsteps, v)

p = Float64[2.2, 1.0, 2.0, 0.4]
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = (α - β*y)*x
    du[2] = (δ*x - γ)*y
end

u0 = Float64[1.0, 1.0]

sensealg = SciMLSensitivity.InterpolatingAdjoint(; autojacvec=ZygoteVJP())
model = ControlledODE(lotka_volterra; solver=Tsit5(), sensealg=sensealg, 
                      tspan=tspan, saveat=tsteps)

# Cost for simulation
function cost(p)
    return sum(abs2, model(u0, control, p))
end

# Compute the gradient
Flux.gradient(_p -> cost(_p), p)
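
If the end goal is to optimize the Lotka-Volterra coefficients under this cost, the same setup can be handed to Optimization.jl. A minimal sketch, assuming OptimizationOptimisers is also in the environment (the optimizer, step size, and iteration count here are placeholders, not recommendations):

using Optimization, OptimizationOptimisers

# objective in Optimization.jl's (u, p) convention; gradients come from Zygote,
# which in turn uses the adjoint sensealg inside solve
optf = OptimizationFunction((p_opt, _) -> cost(p_opt), Optimization.AutoZygote())
optprob = OptimizationProblem(optf, p)
res = solve(optprob, OptimizationOptimisers.Adam(0.01); maxiters = 100)
res.u  # optimized parameter vector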

Thanks @ChrisRackauckas! @Oscar_Smith, I'm not sure what the best way is to share a flamegraph; I have it as either an SVG or an HTML file.

The image appears to be missing.

Use StatProfilerHTML.jl (https://github.com/tkluck/StatProfilerHTML.jl), which shows Julia profiling data in an explorable HTML page, and then share the whole folder.

Perfect, here is the whole folder from StatProfiler in a google drive folder: https://drive.google.com/drive/folders/1BRddREe0AdeUPkxvexiZTXV5BIcmK74e?usp=drive_link

It looks like you profiled compilation. Run it first, then profile it.

Darn it, I made sure to run the gradient computation once beforehand, and then I ran @profilehtml on the gradient computation.

But somehow it seems to have profiled the compilation again. Hopefully the updated Google Drive folder now has the right thing.

Let me know if it’s still incorrect. I’m running this exactly:

cd(@__DIR__)

using Pkg
Pkg.activate(".")
Pkg.instantiate()

using Flux
using Zygote
using ForwardDiff
using DifferentialEquations
using SciMLSensitivity
using Optimization
using Random
using MAT
using Interpolations
using StatProfilerHTML
using FlameGraphs
using ImageShow

# Define Controlled ODE
struct ControlledODE{Of, So, Se, T, Sa}
    odefunc::Of
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end

function ControlledODE(odefunc;
    solver=Tsit5(),
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0),
    saveat=[])
    return ControlledODE(odefunc, solver, sensealg, tspan, saveat)
end

function (c::ControlledODE)(u0, control, p; tspan=c.tspan, saveat=c.saveat)
    function ode(du, u, _p, t)
        c.odefunc(du, u, _p, t)
        du[length(du)] += control(t)
    end
    prob = ODEProblem(ode, u0, tspan, p)
    return solve(prob, c.solver; sensealg=c.sensealg, saveat=saveat)
end

Flux.@functor ControlledODE

# Lotka Volterra with control input
tspan = (0.0, 10.0)
Nt = 100
tsteps = LinRange(tspan[1], tspan[2], Nt)

v = sin.(pi.*tsteps)
control = linear_interpolation(tsteps, v)

p = Float64[2.2, 1.0, 2.0, 0.4]
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = (α - β*y)*x
    du[2] = (δ*x - γ)*y
end

u0 = Float64[1.0, 1.0]

sensealg = SciMLSensitivity.InterpolatingAdjoint(; autojacvec=ZygoteVJP())
model = ControlledODE(lotka_volterra; solver=Tsit5(), sensealg=sensealg, 
                      tspan=tspan, saveat=tsteps)

# Cost for simulation
function cost(p)
    return sum(abs2, model(u0, control, p))
end

# Compute the gradient
Flux.gradient(_p -> cost(_p), p)

@profilehtml Flux.gradient(_p -> cost(_p), p)
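
One guess as to why the profile still shows compilation (this is an assumption on my side): `_p -> cost(_p)` is a new anonymous function each time it is written out, so the call under @profilehtml gets compiled again even though the gradient was already run once. Using a single named callable for both the warm-up and the profiled run, and clearing the profile buffer in between, should leave mostly runtime samples:

using Profile

grad_cost(p) = Flux.gradient(cost, p)   # one named function, compiled on the first call

grad_cost(p)               # warm-up: trigger all compilation here
Profile.clear()            # discard samples collected so far (just in case)
@profilehtml grad_cost(p)  # profile only the already-compiled call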