Instabilities when training a neural ODE on a simple dataset

I have an extremely simple dataset that, for some reason, keeps producing instabilities when I try to train a neural ODE (NODE) on it, no matter how low I set the tolerance or which solver I choose. I train it in two stages, first with ADAM and then with BFGS. Sometimes it does well with ADAM and then becomes unstable as soon as it reaches BFGS; sometimes it's unstable from the first ADAM iteration.

using DiffEqFlux, Lux, DifferentialEquations, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random
rng = Random.default_rng()

data = [90.0 120.0 150.0 180.0 210.0;0.003556 0.00408229 0.00739525 0.0160948 0.00926306]
chain = Lux.Chain(Lux.Dense(2,4,relu),Lux.Dense(4,2))
ps, st = Lux.setup(rng,chain)
pinit = Float64.(ComponentArray(ps))
neuralode = NeuralODE(chain, (data[1,1],data[1,end]), RadauIIA5(),saveat=data[1,:],reltol=1e-18,abstol=1e-18)
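function predict_neuralode(p)
    # Convert the ODESolution to a plain matrix (states × saved time points) so it can be compared with data
    Array(first(neuralode(data[:, 1], p, st)))
end

function loss_function(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, data .- pred)
    return loss
end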

losses = Float64[]
callback = function(p,l)
    push!(losses, l)
    if length(losses)%1==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf,pinit)

result_neuralode = Optimization.solve(optprob,
                                           ADAM(),
                                           callback = callback,
                                           maxiters = 300)
optprob2 = remake(optprob,u0 = result_neuralode.u)
optimizedParameters = Optimization.solve(optprob2,
                                            Optim.BFGS(),
                                            callback=callback,
                                            allow_f_increases = false)

Any ideas how to address this?
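A side note on the tolerances in the code above: reltol=1e-18 and abstol=1e-18 are finer than the resolution of Float64 arithmetic, whose machine epsilon is about 2.2e-16, so an adaptive solver cannot reliably meet them and lowering them further does not buy extra accuracy. The base-Julia snippet below only illustrates that a relative change of 1e-18 is rounded away when added to 1.0; the variable names are illustrative.

```julia
# Machine epsilon: the gap between 1.0 and the next representable Float64.
machine_eps = eps(Float64)
println("eps(Float64) = ", machine_eps)   # about 2.22e-16

# A relative change of 1e-18 is far below that gap, so it is lost to rounding.
println(1.0 + 1e-18 == 1.0)

# A requested relative tolerance below machine_eps therefore asks for
# accuracy that Float64 arithmetic cannot represent.
reltol = 1e-18
println(reltol < machine_eps)
```

Consistent with this, the last reply in the thread reports that the BFGS stage also worked with a higher tolerance once the line search was changed.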

Is it unstable at the start? If the first solve is already unstable, then the optimization can't even get started! You may want to tame the neural ODE by making your starting parameters smaller, for example by setting pinit = Float64.(ComponentArray(ps))/100. Or start with a smaller interval. See:
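A minimal sketch of the parameter-shrinking suggestion, assuming the same Lux chain as in the question. The names pinit_small and u0 and the fixed seed are illustrative; the factor of 100 is the one suggested above. It also evaluates the network, i.e. the right-hand side of the neural ODE, at the first data point (where the time value 90.0 is one of the two state components), so you can compare the size of the initial derivative before and after scaling.

```julia
using Lux, ComponentArrays, Random

rng = Random.default_rng()
Random.seed!(rng, 1234)   # fixed seed only to make the comparison repeatable

# Same architecture as in the question.
chain = Lux.Chain(Lux.Dense(2, 4, relu), Lux.Dense(4, 2))
ps, st = Lux.setup(rng, chain)

pinit = Float64.(ComponentArray(ps))
# Shrink every weight and bias by a factor of 100 so the initial
# vector field (and hence the first ODE solve) starts out small.
pinit_small = pinit ./ 100

# First data point, used as the initial condition in the question.
u0 = [90.0, 0.003556]

# Evaluate the network (the ODE right-hand side) at u0 with both parameter sets.
du_default, _ = chain(u0, pinit, st)
du_small, _ = chain(u0, pinit_small, st)

println("|du/dt| with default init: ", sqrt(sum(abs2, du_default)))
println("|du/dt| with scaled init:  ", sqrt(sum(abs2, du_small)))
```

In the question's script, the only change needed would be to use the scaled vector as the initial guess passed to OptimizationProblem.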

predict_neuralode(pinit) did run without errors, but something I forgot to mention is that the loss values are absurd throughout the optimization (in some runs, loss_function(pinit) is on the order of 1e60). Reducing the training interval didn't help either.

Shrinking pinit worked great though, thanks!

I guess my first run was a lucky one. I've tried running it again (even using just the first 3 time points), and ADAM runs well, but BFGS eventually reaches a region where the solve becomes unstable even with this low a tolerance.
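Training on only the first few time points amounts to slicing the data matrix and deriving the time span and save points from the slice. A minimal base-Julia sketch, where k = 3 matches the "first 3 time points" mentioned above; the *_short names are hypothetical, and the NeuralODE would then be rebuilt with tspan_short and saveat = saveat_short.

```julia
# Row 1 holds the time points, row 2 the measured values (as in the question).
data = [90.0 120.0 150.0 180.0 210.0
        0.003556 0.00408229 0.00739525 0.0160948 0.00926306]

k = 3                          # number of leading time points to keep
data_short = data[:, 1:k]      # 2×k slice of the training data

# Time span and save points taken from the slice, for rebuilding the NeuralODE.
tspan_short = (data_short[1, 1], data_short[1, end])
saveat_short = data_short[1, :]

println("tspan = ", tspan_short)     # (90.0, 150.0)
println("saveat = ", saveat_short)   # [90.0, 120.0, 150.0]
```

The loss function would likewise need to compare against data_short rather than the full data matrix.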

Try backtracking linesearch in BFGS?
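This suggestion maps onto Optim's linesearch keyword: BFGS uses the HagerZhang line search by default, and LineSearches.BackTracking() can be passed instead (LineSearches must be added to the environment). A plausible reason it helps here is that backtracking only ever shrinks the trial step, whereas HagerZhang may extrapolate to larger steps, which can land in parameter regions where the ODE solve blows up. As a self-contained illustration, the sketch below uses the Rosenbrock function as a stand-in objective rather than the neural ODE loss; the function names are illustrative.

```julia
using Optim, LineSearches

# Stand-in objective (Rosenbrock), not the neural ODE loss from the question.
rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

# In-place analytic gradient, so no AD package is needed.
function rosenbrock_grad!(G, x)
    G[1] = -2.0 * (1.0 - x[1]) - 400.0 * x[1] * (x[2] - x[1]^2)
    G[2] = 200.0 * (x[2] - x[1]^2)
    return G
end

x0 = [-1.2, 1.0]

# BFGS with a backtracking line search instead of the default HagerZhang.
res_bt = optimize(rosenbrock, rosenbrock_grad!, x0,
                  BFGS(linesearch = LineSearches.BackTracking()))
println("minimizer: ", Optim.minimizer(res_bt))

# In the question's setup the equivalent change would be (not run here):
# Optimization.solve(optprob2, Optim.BFGS(linesearch = LineSearches.BackTracking()),
#                    callback = callback, allow_f_increases = false)
```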

That worked out great even with a higher tolerance, thanks!