I am trying to train a neural ODE on different time trajectories, each starting from a different initial condition. I can observe a reduction in loss with ADAM (with good outputs), but BFGS does not work at all. Why is that the case, and what else can I try to get better results?
Are you pre-training with ADAM first? BFGS can get stuck in local minima quite easily.
Yes, starting with ADAM followed by BFGS, as in the tutorials.
Show the code.
Here is the code.
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using Optim, Flux, Lux, OrdinaryDiffEq, ComponentArrays, Statistics
using DiffEqFlux: multiple_shoot   # multiple_shoot lives in DiffEqFlux
using IterTools: ncycle
using Random

rng = Random.default_rng()

ann = Lux.Chain(Lux.Dense(2, 10, tanh),
                Lux.Dense(10, 10, tanh),
                Lux.Dense(10, 1))
p, st = Lux.setup(rng, ann)
group_size = 3          # number of points per multiple-shooting segment
continuity_term = 200   # weight on the continuity penalty between segments
function loss_function(data, pred)
    return sum(abs2, data - pred)
end
function loss_multiple_shooting(p, data)
    y_train = data[2:end]'   # observations after the initial point
    x0 = data[1]             # this trajectory's initial condition
    # ODE problem parametrized by the initial condition; tspan and tsteps
    # are assumed to be defined earlier (not shown). ann returns the tuple
    # (output, state), so [1] takes the network output.
    prob = ODEProblem((u, p, t) -> ann([x0, u[1]], p, st)[1], x0, tspan, p)
    return multiple_shoot(p, y_train, tsteps, prob, loss_function, Tsit5(),
                          group_size; continuity_term)
end
# train_data contains the time trajectories and their initial conditions
train_loader = Flux.Data.DataLoader((train_data,), batchsize = 1)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p, batch) -> loss_multiple_shooting(x, batch), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float32}(p))
result_neuralode = Optimization.solve(optprob, ADAM(0.01),
                                      ncycle(train_loader, 100))
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(optprob2,
                                       Optim.BFGS(initial_stepnorm = 0.01),
                                       ncycle(train_loader, 20);
                                       allow_f_increases = true)
I see you’re using a data loader. BFGS doesn't work well with stochastic loss functions: its Hessian approximation and line search assume the objective is the same deterministic function at every evaluation, so swapping the batch between steps breaks it.
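If you want to use BFGS, sum over all trajectories in a single deterministic objective instead of cycling batches. A minimal sketch, assuming train_data is a plain collection of trajectories in the format loss_multiple_shooting expects (full_loss and the _det names are just illustrative):

function full_loss(p, _)
    # Deterministic objective: every trajectory contributes on every call,
    # so successive gradient evaluations are consistent for BFGS.
    return sum(loss_multiple_shooting(p, data) for data in train_data)
end

optf_det = Optimization.OptimizationFunction(full_loss, Optimization.AutoZygote())
optprob_det = Optimization.OptimizationProblem(optf_det, result_neuralode.u)
result_bfgs = Optimization.solve(optprob_det,
                                 Optim.BFGS(initial_stepnorm = 0.01);
                                 allow_f_increases = true,
                                 maxiters = 300)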
Is there a way to generate multiple time series starting with different initial conditions?
Just solve at each initial condition. Or make the initial conditions a matrix if you have an NN like that. But make sure you include all initial conditions in every loss evaluation if you use BFGS: it's not an optimizer that can handle changing what you evaluate between steps.
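Here's a minimal sketch of both options (the dynamics f, the time grid, and u0s are all made up for illustration):

using OrdinaryDiffEq

f(u, p, t) = -u .+ sin.(t)                       # made-up ground-truth dynamics
tspan  = (0.0f0, 3.0f0)
tsteps = range(tspan[1], tspan[2]; length = 30)
u0s    = [0.5f0, 1.0f0, -0.8f0]                  # one entry per trajectory

# Solve once per initial condition; since saveat includes t = 0, each
# trajectory's first entry is its initial condition, matching the data
# layout loss_multiple_shooting expects.
base_prob  = ODEProblem(f, u0s[1], tspan)
train_data = map(u0s) do u0
    Array(solve(remake(base_prob; u0 = u0), Tsit5(); saveat = tsteps))
end

# Or batch them: one column per initial condition, solved in a single call
# (the "make it a matrix" option; the NN then sees batched columns).
u0_mat      = reshape(u0s, 1, :)                 # 1×3 matrix of initial conditions
sol_batched = solve(ODEProblem(f, u0_mat, tspan), Tsit5(); saveat = tsteps)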