# Oscillating loss curves using sciml_train

I’ve been trying to adapt one of the examples from Chris Rackauckas’s UDE paper (with the code here) to a slightly more complicated system of equations, and I’m running into oscillatory loss curves. In his example, the nonlinear terms in the Lotka-Volterra (LV) equations are approximated with a neural network, so
\frac{dx}{dt} = ax - bxy
\frac{dy}{dt} = cxy - dy
becomes
\frac{dx}{dt} = ax + U_1([x,y],p)
\frac{dy}{dt}= - dy+ U_2([x,y],p)

In my example, I’m working with the following equations:
\frac{dx}{dt} = (ax - bxy)*xt
\frac{dy}{dt} = (cxy - dy)*yt
which I am representing with a neural network and its partial derivative with respect to time as:
\frac{dx}{dt} = \left(ax - by\frac{\partial}{\partial t}[U_1([x,y,t],p)]\right)U_1([x,y,t],p)
\frac{dy}{dt}= \left(- dy+ cx\frac{\partial}{\partial t}[U_2([x,y,t],p)]\right)U_2([x,y,t],p)
The goal is for the network to learn U_1 = xt and U_2 = yt.
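
As a consistency check on this formulation, the target outputs can be substituted back in. This is only a sketch, and it assumes that x, y and t are treated as independent network inputs, so the partial derivative with respect to t holds x and y fixed:

```latex
\begin{aligned}
U_1 = xt \;&\Rightarrow\; \frac{\partial U_1}{\partial t} = x, &
\left(ax - by\,x\right)xt &= (ax - bxy)\,xt \\
U_2 = yt \;&\Rightarrow\; \frac{\partial U_2}{\partial t} = y, &
\left(-dy + cx\,y\right)yt &= (cxy - dy)\,yt
\end{aligned}
```

Both right-hand sides match the original equations, so the formulation is consistent if the network derivative is taken with respect to its third input only.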

I’ve taken the code from the above link and made a couple of minor changes to accommodate this:

using OrdinaryDiffEq
using ModelingToolkit
using LinearAlgebra, Statistics, Optim  # Statistics provides mean, used below
using DiffEqFlux, Flux
using Plots

using Random
Random.seed!(1234)

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = (α*u[1] - β*u[2]*u[1]) * (u[1]*t)
    du[2] = (γ*u[1]*u[2]  - δ*u[2]) * (u[2]*t)
end

tspan = (0.0f0,3.0f0)
u0 = Float32[0.44249296,4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)

X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = Float32(5e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
rbf(x) = exp.(-(x.^2))

U = FastChain(FastDense(3,5,rbf), FastDense(5,5, rbf), FastDense(5,5, rbf), FastDense(5,2))
p = initial_params(U)

function ude_dynamics!(du, u, p, t, p_true)
    x, y = u
    A = [x, y, t]
    û = U(A, p) # Network prediction
    # Intended as ∂U/∂t: the third column of the Jacobian with respect to the inputs [x, y, t]
    dûdt = DiffEqFlux.jacobian(A->U(A,p),A)[1][:,3]
    du[1] = (p_true[1]*u[1] - p_true[2]*dûdt[1]*u[2]) * û[1]
    du[2] = (-p_true[4]*u[2] + p_true[3]*dûdt[2]*u[1]) * û[2]
end

nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_)
prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:,1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
                tspan = (T[1], T[end]), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end

function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ .- X̂)
end

losses = Float32[]

callback(θ,l) = begin
    push!(losses, l)
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations:$(losses[end])")
    end
    false
end

res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.1f0), cb=callback, maxiters = 1000)
println("Training loss after $(length(losses)) iterations:$(losses[end])")

res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 20000)
println("Final training loss after $(length(losses)) iterations:$(losses[end])")

I used 1000 steps of ADAM and then 20000 steps of BFGS, but my loss plots look quite poor: they jump around and oscillate, as shown below. I’ve tried various iteration counts, but every run exits training early, after a couple hundred iterations of BFGS. I think I have some misunderstanding of when to use ADAM vs. BFGS (why do I use one, then the other? Should I use a different one, add a third one in between, etc.?)

I also don’t fully understand why training stops during the BFGS step, when the error is decreasing again, and not during the ADAM step, when it’s jumping around. The message I get is "Warning: Interrupted. Larger maxiters is needed.", yet it seems like I almost have too many ADAM iterations, considering the error jumps up again after a couple hundred.

Any help to resolve this would be greatly appreciated.

That’s from having too high a learning rate on ADAM. Drop it a bit, like ADAM(1f-2). Also decrease your solver tolerances (abstol and reltol) so you get more accurate gradients.
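
One way to see how the step size affects the loss trace is to compare two rates on a small stand-in problem. The sketch below is a hand-written Adam update in plain Julia, not Flux's ADAM, and `adam_trace`, `toy_loss`, `toy_grad` and `n_increases` are hypothetical names introduced only for this illustration. The badly scaled quadratic is an assumed toy loss, not the UDE loss from the question.

```julia
# Minimal hand-written Adam (illustration only, not Flux's ADAM).
# Returns the loss after every iteration so two step sizes can be compared.
function adam_trace(loss, grad, θ0; η, iters = 300, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    θ = float.(copy(θ0))
    m = zero(θ)                      # first-moment estimate
    v = zero(θ)                      # second-moment estimate
    trace = Float64[]
    for k in 1:iters
        g = grad(θ)
        @. m = β1 * m + (1 - β1) * g
        @. v = β2 * v + (1 - β2) * g^2
        m̂ = m ./ (1 - β1^k)          # bias correction
        v̂ = v ./ (1 - β2^k)
        @. θ -= η * m̂ / (sqrt(v̂) + ϵ)
        push!(trace, loss(θ))
    end
    return trace
end

# Hypothetical stand-in loss: a badly scaled quadratic with its exact gradient.
toy_loss(θ) = 0.5 * (100 * θ[1]^2 + θ[2]^2)
toy_grad(θ) = [100 * θ[1], θ[2]]

# Number of iterations on which the loss went up instead of down.
n_increases(trace) = count(>(0), diff(trace))

for η in (0.1, 0.01)
    trace = adam_trace(toy_loss, toy_grad, [1.0, 1.0]; η = η)
    println("η = $η: final loss = $(trace[end]), increases = $(n_increases(trace))")
end
```

Counting the iterations on which the loss increases gives a rough, plot-free measure of how much a given rate oscillates. The same comparison could be made on the real problem by collecting `losses` for each rate.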

Then for the BFGS part, you don’t have a branch to catch divergent solves. Handle that with an if sol.retcode !== :Success branch, as the documentation shows. Also, you might need to allow for integrators with automatic stiffness switching, like AutoVern7(TRBDF2()), since this looks like the optimizer is finding a stiffer problem.
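
A minimal sketch of such a guard is shown below. It uses a hypothetical toy ODE rather than the UDE from the question, and `blowup!`, `toy_prob` and `guarded_loss` are names introduced only for illustration. The toy problem du/dt = p*u^2 blows up in finite time when p > 0, standing in for a parameter set that makes the solve diverge. The return code is compared through `Symbol(...)` on the assumption that this works both where `retcode` is a plain `Symbol` and where it is an enum.

```julia
using OrdinaryDiffEq

# Hypothetical toy dynamics: du/dt = p*u^2 with u(0) = 1.
# For p = -1 the exact solution is 1/(1 + t); for p = 1 it blows up at t = 1.
function blowup!(du, u, p, t)
    du[1] = p[1] * u[1]^2
    return nothing
end

ts = 0.0:0.1:2.0
data = reshape(collect(1.0 ./ (1.0 .+ ts)), 1, :)   # exact data for p = -1
toy_prob = ODEProblem(blowup!, [1.0], (first(ts), last(ts)), [-1.0])

function guarded_loss(θ)
    sol = solve(toy_prob, Tsit5(), p = θ, saveat = ts,
                abstol = 1e-8, reltol = 1e-8, verbose = false)
    # A failed solve stops early and returns fewer saved points than `data`
    # has columns, so return a large loss before attempting the subtraction.
    if Symbol(sol.retcode) != :Success
        return Inf
    end
    return sum(abs2, data .- Array(sol))
end

println(guarded_loss([-1.0]))   # solve succeeds, small finite loss
println(guarded_loss([1.0]))    # solve diverges, guard returns Inf
```

Returning `Inf` lets the optimizer's line search reject the step instead of throwing a dimension-mismatch error. In the question's code the same check would go where `predict` and `loss` call `solve`.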

Thank you! This is very helpful. There are a lot of knobs to tweak and it seems to take a lot of practice to learn which knobs to turn and which way to turn them.

I’m also a little confused on applying ContinuousDataDrivenProblem here – in the code example