I have an extremely simple dataset that, for some reason, keeps producing instabilities when I train a neural ODE (NODE) on it, no matter how low I set the tolerances or which solver I choose. I train in two stages, first with ADAM and then with BFGS. Sometimes it does well with ADAM and then becomes unstable as soon as it reaches BFGS; sometimes it is unstable from the very first ADAM iteration.
using DiffEqFlux, Lux, DifferentialEquations, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random
rng = Random.default_rng()
# row 1: time points, row 2: observed values
data = [90.0 120.0 150.0 180.0 210.0; 0.003556 0.00408229 0.00739525 0.0160948 0.00926306]
# 2 -> 4 -> 2 MLP, so the NODE state stays 2-dimensional to match the data
chain = Lux.Chain(Lux.Dense(2, 4, relu), Lux.Dense(4, 2))
ps, st = Lux.setup(rng,chain)
pinit = Float64.(ComponentArray(ps))
# tolerances deliberately tightened while debugging; the instability shows up regardless
neuralode = NeuralODE(chain, (data[1, 1], data[1, end]), RadauIIA5(); saveat = data[1, :], reltol = 1e-18, abstol = 1e-18)
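# (Sketch of one of the non-stiff alternatives I have tried; the exact
# tolerances here are illustrative rather than the only ones tested.)
# neuralode = NeuralODE(chain, (data[1, 1], data[1, end]), Tsit5();
#                       saveat = data[1, :], reltol = 1e-8, abstol = 1e-8)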
function predict_neuralode(p)
    # convert the ODESolution to a plain Array so the broadcast in the loss is well-defined
    Array(first(neuralode(data[:, 1], p, st)))
end
function loss_function(p)
    pred = predict_neuralode(p)
    return sum(abs2, data .- pred)
end
losses = Float64[]
callback = function (p, l)
    push!(losses, l)
    # the original `% 1 == 0` check is always true, so this prints on every iteration
    println("Current loss after $(length(losses)) iterations: $(losses[end])")
    return false
end
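# A variant I could switch to that halts as soon as the loss goes non-finite,
# instead of letting the optimizer keep stepping through an unstable region
# (sketch; returning `true` from the callback asks Optimization.solve to stop):
halting_callback = function (p, l)
    push!(losses, l)
    println("Current loss after $(length(losses)) iterations: $(losses[end])")
    return !isfinite(l)
end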
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf,pinit)
result_neuralode = Optimization.solve(optprob,
                                      ADAM(),
                                      callback = callback,
                                      maxiters = 300)
optprob2 = remake(optprob, u0 = result_neuralode.u)
optimizedParameters = Optimization.solve(optprob2,
                                         Optim.BFGS(),
                                         callback = callback,
                                         allow_f_increases = false)
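For completeness, the DiffEqFlux tutorials damp the initial BFGS step, so one variant I could try is the call below (the 0.01 step norm is the tutorial value, not something I have tuned for this problem):

optimizedParameters = Optimization.solve(optprob2,
                                         Optim.BFGS(initial_stepnorm = 0.01),
                                         callback = callback,
                                         allow_f_increases = false)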
Any ideas how to address this?