Oscillating loss curves using sciml_train

I’ve been trying to adapt one of the examples from Chris Rackauckas’s UDE paper (with the code here) to a slightly more complicated system of equations, and I’m running into oscillating loss curves. In his example, the nonlinear terms of the Lotka–Volterra (LV) equations are approximated by a neural network, so

$$\frac{dx}{dt} = ax - bxy, \qquad \frac{dy}{dt} = cxy - dy$$

becomes

$$\frac{dx}{dt} = ax + U_1([x,y],p), \qquad \frac{dy}{dt} = -dy + U_2([x,y],p)$$

In my example, I’m working with the following equations:

$$\frac{dx}{dt} = (ax - bxy)\,xt, \qquad \frac{dy}{dt} = (cxy - dy)\,yt$$

which I am representing with a NN and its time derivative as:

$$\frac{dx}{dt} = \left(ax - by\,\frac{\partial}{\partial t}\big[U_1([x,y,t],p)\big]\right)U_1([x,y,t],p)$$

$$\frac{dy}{dt} = \left(-dy + cx\,\frac{\partial}{\partial t}\big[U_2([x,y,t],p)\big]\right)U_2([x,y,t],p)$$

I’m trying to get the network to learn $U_1 = xt$ and $U_2 = yt$.
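As a sanity check on this reformulation: plugging the hoped-for outputs $U_1 = xt$ and $U_2 = yt$ (so $\partial U_1/\partial t = x$ and $\partial U_2/\partial t = y$, holding $x, y$ fixed) into the UDE right-hand side should reproduce the true system exactly. A small pure-Julia sketch; the function names are illustrative, and the partial derivative is taken by a central finite difference only to mirror the role of `dûdt` in the code below:

```julia
# True right-hand side: (ax - bxy)·xt and (cxy - dy)·yt
true_rhs(x, y, t, (a, b, c, d)) = [(a*x - b*x*y) * x*t, (c*x*y - d*y) * y*t]

# Hypothetical stand-in for a perfectly trained network: U = [xt, yt]
U_target(x, y, t) = [x*t, y*t]

# UDE right-hand side, with ∂U/∂t estimated by a central finite difference in t
function ude_rhs(x, y, t, (a, b, c, d); h = 1e-6)
    û = U_target(x, y, t)
    dûdt = (U_target(x, y, t + h) .- U_target(x, y, t - h)) ./ (2h)
    return [(a*x - b*y*dûdt[1]) * û[1], (-d*y + c*x*dûdt[2]) * û[2]]
end

p = (1.3, 0.9, 0.8, 1.8)  # same values as p_ in the code below
for (x, y, t) in [(0.44, 4.63, 0.0), (1.2, 0.7, 1.5), (2.0, 3.0, 3.0)]
    println(true_rhs(x, y, t, p), "  vs  ", ude_rhs(x, y, t, p))
end
```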

I’ve taken the code from the above link and made a couple of minor changes to accommodate this:

```julia
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, Optim
using DiffEqFlux, Flux
using Plots
using Statistics  # for mean

using Random

# Ground truth: LV terms multiplied by x*t and y*t
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = (α*u[1] - β*u[2]*u[1]) * (u[1]*t)
    du[2] = (γ*u[1]*u[2] - δ*u[2]) * (u[2]*t)
end

tspan = (0.0f0, 3.0f0)
u0 = Float32[0.44249296, 4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)

X = Array(solution)
t = solution.t

# Add 5% noise, scaled by the mean of each state
x̄ = mean(X, dims = 2)
noise_magnitude = Float32(5e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
rbf(x) = exp.(-(x.^2))

# Network U([x, y, t], p) -> [U₁, U₂]
U = FastChain(FastDense(3,5,rbf), FastDense(5,5, rbf), FastDense(5,5, rbf), FastDense(5,2))
p = initial_params(U)

function ude_dynamics!(du, u, p, t, p_true)
    x, y = u
    A = [x, y, t]
    û = U(A, p) # Network prediction
    # Column 3 of the Jacobian w.r.t. [x, y, t] is ∂U/∂t
    dûdt = DiffEqFlux.jacobian(A->U(A,p), A)[1][:,3]
    du[1] = (p_true[1]*u[1] - p_true[2]*dûdt[1]*u[2]) * û[1]
    du[2] = (-p_true[4]*u[2] + p_true[3]*dûdt[2]*u[1]) * û[2]
end

# Fix the known parameters; only the network weights are trained
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:,1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
                tspan = (T[1], T[end]), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = ForwardDiffSensitivity()))
end

function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ .- X̂)
end

losses = Float32[]

callback(θ, l) = begin
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false  # returning true would stop training
end

res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.1f0), cb=callback, maxiters = 1000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 20000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
```

I used 1000 steps of ADAM followed by 20000 steps of BFGS, but my loss curves look quite poor: they jump around and oscillate, as shown below. I’ve tried various iteration counts, but every run kicks me out of training early, after a couple hundred iterations of BFGS. I think I misunderstand when to use ADAM vs. BFGS (why use one and then the other? Should I use a different optimizer, or add a third one in between?)
[Screenshot: training loss curve over the ADAM and BFGS iterations]
I don’t fully understand why training stops during the BFGS phase, when the error is decreasing again, and not during the ADAM phase, when it’s jumping around. The message I get is `Warning: Interrupted. Larger maxiters is needed.`, but looking at the curve, it seems I almost have too many ADAM iterations, given that the error jumps up again after a couple hundred.

Any help to resolve this would be greatly appreciated.

That’s from having a high learning rate on ADAM. Drop it a bit, e.g. `ADAM(1f-2)`. Also tighten your solver tolerances (smaller `abstol`/`reltol`) so you get more accurate gradients.
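Why a large ADAM rate makes the loss jump: Adam divides each gradient component by a running estimate of its magnitude, so early steps move every parameter by roughly the learning rate, however small or large the gradient is. A minimal hand-written Adam update (an illustrative sketch with the usual defaults β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8, not Flux’s source; `AdamState` and `adam_step!` are hypothetical names) shows this on the very first step:

```julia
# Hypothetical minimal Adam state: first/second moment estimates and step count
mutable struct AdamState
    m::Vector{Float64}
    v::Vector{Float64}
    t::Int
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

# One bias-corrected Adam update of θ in place, given gradient g
function adam_step!(θ, g, s::AdamState; η = 0.1, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    s.t += 1
    @. s.m = β1 * s.m + (1 - β1) * g
    @. s.v = β2 * s.v + (1 - β2) * g^2
    m̂ = s.m ./ (1 - β1^s.t)   # bias correction
    v̂ = s.v ./ (1 - β2^s.t)
    @. θ -= η * m̂ / (sqrt(v̂) + ϵ)
    return θ
end

# Gradients spanning six orders of magnitude all produce a first step of ≈ η
for η in (0.1, 0.01)
    θ = zeros(3)
    adam_step!(θ, [1e-3, 1.0, 1e3], AdamState(3); η = η)
    println("η = $η, first step: ", θ)
end
```

With `ADAM(0.1f0)`, every network weight therefore moves by about 0.1 at the start of training, which on a model this sensitive can easily throw the loss back up.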

Then for the BFGS part, you don’t have a branch to catch divergent solves. Handle that with an `if sol.retcode !== :Success` branch, as the documentation shows. You might also need an integrator that can switch methods, like `AutoVern7(TRBDF2())`, since that behavior suggests the optimizer is finding a stiffer problem.
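A self-contained sketch of such a guard, on a toy problem of our own rather than the UDE above: $du/dt = p\,u^2$ with $u(0) = 1$ blows up at $t = 1/p$, so a parameter like $p = 2$ makes the solve stop early. Instead of comparing `sol.retcode` (which became an enum rather than a `Symbol` in later SciML releases), this version checks that the solve reached every save point with finite values, and it uses the `AutoVern7(TRBDF2())` switching integrator mentioned above. Returning `Inf` signals the optimizer to reject that step.

```julia
using OrdinaryDiffEq

# Toy ODE (illustrative only): du/dt = p*u^2, exact solution u(t) = 1/(1 - p*t)
function blowup!(du, u, p, t)
    du[1] = p[1] * u[1]^2
    return nothing
end

ts = 0.0:0.1:1.0
prob = ODEProblem(blowup!, [1.0], (0.0, 1.0), [0.5])
data = [1 / (1 - 0.5 * τ) for τ in ts]   # synthetic observations for p = 0.5

function guarded_loss(θ)
    # AutoVern7(TRBDF2()) switches to TRBDF2 if stiffness is detected
    sol = solve(remake(prob, p = θ), AutoVern7(TRBDF2()), saveat = ts,
                abstol = 1e-8, reltol = 1e-8)
    # Divergence guard: the solve stopped early or produced non-finite values.
    # (The 2022-era equivalent is `sol.retcode !== :Success`; newer SciMLBase
    # versions provide SciMLBase.successful_retcode(sol) instead.)
    if length(sol.t) != length(ts) || any(!isfinite, Array(sol))
        return Inf
    end
    return sum(abs2, vec(Array(sol)) .- data)
end

guarded_loss([0.5])   # small: the solution stays finite on [0, 1]
guarded_loss([2.0])   # Inf: the solution blows up at t = 0.5
```

In the original code, the same check would go inside `loss` after calling `solve` (so `predict` would need to return the solution object rather than `Array(...)`).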

Thank you! This is very helpful. There are a lot of knobs to tweak and it seems to take a lot of practice to learn which knobs to turn and which way to turn them.

I’m also a little confused about applying `ContinuousDataDrivenProblem` here. In the code example `ContinuousDataDrivenProblem(X̂, ts, DX = Ŷ)`, $\hat{X}$ is the solution of the problem (the time series of $x, y$ values given the fitted parameters) and $\hat{Y}$ is the neural network evaluated with the fitted parameters at $\hat{X}$. I don’t see why `DX = Ŷ`; maybe I have some fundamental misunderstanding. Do I have to modify this approach for the more complicated differential equations? Thanks in advance.
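As a rough illustration of what `DX` does: a `ContinuousDataDrivenProblem` pairs the state data `X` with the data `DX` that the sparse regression is asked to explain, roughly solving $DX \approx \Theta(X)\,\Xi$ for a library $\Theta$ of candidate terms. Passing `DX = Ŷ` therefore makes the network’s output (in the paper, the missing interaction terms) the regression target, not the full derivative. The sketch below mimics that regression in plain Julia with a hand-written sequentially thresholded least squares (`stlsq` is a hypothetical helper, not the package’s implementation) and random samples standing in for $\hat{X}$ and $\hat{Y}$. It assumes the network has learned $U \approx [xt, yt]$; under that assumption the target depends on $t$, so the candidate library needs $t$-dependent terms.

```julia
using LinearAlgebra, Random

# Hypothetical helper: sequentially thresholded least squares (SINDy-style),
# fitting DX ≈ Θ * Ξ one target column at a time
function stlsq(Θ, DX; λ = 0.1, iters = 10)
    Ξ = Θ \ DX                       # initial least-squares fit
    for _ in 1:iters
        small = abs.(Ξ) .< λ
        Ξ[small] .= 0                # prune small coefficients
        for k in axes(DX, 2)         # refit each target on its surviving terms
            keep = .!small[:, k]
            Ξ[keep, k] = Θ[:, keep] \ DX[:, k]
        end
    end
    return Ξ
end

Random.seed!(1)
n = 200
x, y, t = rand(n) .* 3, rand(n) .* 5, rand(n) .* 3   # stand-ins for X̂ and times

# Stand-in for Ŷ: a network that has learned U₁ = x t, U₂ = y t
Ŷ = hcat(x .* t, y .* t)

# Candidate library Θ(x, y, t), including t-dependent terms
terms = ["1", "x", "y", "t", "x*y", "x*t", "y*t"]
Θ = hcat(ones(n), x, y, t, x .* y, x .* t, y .* t)

Ξ = stlsq(Θ, Ŷ)
for k in 1:2
    active = [terms[i] => round(Ξ[i, k], digits = 3) for i in eachindex(terms) if Ξ[i, k] != 0]
    println("U$k ≈ ", active)
end
```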