Training a neural ODE with an unknown time span

I am trying to train a neural ODE with a start time of 0 and an unknown end time.

Here is the working code.

using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots, ComponentArrays
y_train = [1.0,0.9983964550640194,0.9929125761989865,0.9792396014395183,0.9546845545572502,0.9154772688324724,
0.862462359536808,0.7983205620975837,0.7286092001860601,0.6583837246309399,0.5915979141675031,0.5284476216123583,
0.47150341518348965,0.4182436898670648,0.3720713883516537,0.3309667784659828,0.29715768599897174,0.2664332753935418,
0.23903836267045317,0.21234117560653165,0.18718632947340072,0.16399001150635295,0.14353563297182165,0.12440326094939652,
0.10676426665360979,0.0908634661052219,0.07808406982152909,0.06742233211741375,0.05613631356035937,0.045021666217837306,
0.035106617377041134,0.026513575048351168,0.018471368766372075,0.011041202536293943,0.005116654833892321,0.0]
display(plot(y_train))
tend = 1.0f0 #Chosen Randomly
tspan = (0.0f0,tend)
tsteps = collect(range(0, tend, length=length(y_train)))
ann = Lux.Chain(Lux.Dense(1,10,Lux.tanh),
                Lux.Dense(10,10,Lux.tanh),
                Lux.Dense(10,1))
rng = Random.default_rng()  # rng was not defined before being passed to Lux.setup
p, st = Lux.setup(rng, ann)
prob_neuralode = NeuralODE(ann, tspan, Tsit5(), saveat = tsteps)
u0 = [1f0]

function predict_neuralode(p)
  Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
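    # pred is a 1×N matrix and y_train a length-N vector; transpose y_train so the
    # subtraction is elementwise instead of broadcasting to an N×N matrix
    loss = sum(abs2, y_train' .- pred)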
    return loss, pred
end
callback = function (p, l, pred)
  println(l)
  return false
end
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float32}(p))

result_neuralode = Optimization.solve(optprob,
                                       ADAM(0.001),
                                       callback = callback,
                                       maxiters = 1000)

plot(predict_neuralode(result_neuralode.u)')

I have tried different end times for the time span and different ODE solvers, but the loss quickly saturates at a high value. What else can I try here?

Did you try the standard suggestions from the docs?

https://docs.sciml.ai/SciMLSensitivity/stable/tutorials/training_tips/local_minima/

https://docs.sciml.ai/DiffEqFlux/stable/examples/multiple_shooting/
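
To make the multiple shooting idea concrete, here is a minimal dependency-free sketch of how such a loss can be structured. It is not the DiffEqFlux API: every name in it (rk4_step, integrate, group_ranges, multiple_shooting_loss, decay) is hypothetical, the fixed-step RK4 integrator stands in for a real solver, and the toy right-hand side stands in for the neural network. The time grid is split into short overlapping groups, each group is integrated starting from the observed value at its first point, and a penalty ties the end of each group to the start of the next.

```julia
# One classical RK4 step for the scalar autonomous ODE du/dt = f(u, p).
function rk4_step(f, u, p, dt)
    k1 = f(u, p)
    k2 = f(u + dt / 2 * k1, p)
    k3 = f(u + dt / 2 * k2, p)
    k4 = f(u + dt * k3, p)
    return u + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Integrate from u0 across the time points ts and return the state at every point.
function integrate(f, u0, p, ts)
    us = [u0]
    for i in 2:length(ts)
        push!(us, rk4_step(f, us[end], p, ts[i] - ts[i - 1]))
    end
    return us
end

# Split indices 1:n into groups of at most `group_size` points that overlap by one point.
function group_ranges(n, group_size)
    starts = 1:(group_size - 1):(n - 1)
    return [s:min(s + group_size - 1, n) for s in starts]
end

# Sum of per-group data misfits plus a continuity penalty between neighbouring groups.
function multiple_shooting_loss(f, p, ts, data; group_size = 5, continuity_weight = 100.0)
    n = length(ts)
    loss = 0.0
    for r in group_ranges(n, group_size)
        # Each group restarts from the observed value at its first time point.
        pred = integrate(f, data[first(r)], p, ts[r])
        loss += sum(abs2, pred .- data[r])
        # The next group starts from data[last(r)], so penalise the gap to it.
        if last(r) < n
            loss += continuity_weight * abs2(pred[end] - data[last(r)])
        end
    end
    return loss
end

# Toy problem (an assumption for illustration): du/dt = -p[1] * u with true p = [2.0].
decay(u, p) = -p[1] * u
ts = range(0.0, 1.0, length = 36)
data = exp.(-2.0 .* ts)

println(multiple_shooting_loss(decay, [2.0], ts, data))  # loss at the true parameter
println(multiple_shooting_loss(decay, [0.5], ts, data))  # loss at a poor guess
```

Because each group only has to follow the data over a few points, the loss landscape tends to be better behaved than when fitting the whole trajectory in one shot. For actual training, the multiple shooting page linked above shows the library routine to use in place of this hand-rolled loss.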

Multiple shooting worked perfectly. Thank you.

Does multiple shooting work with UDEs?

Yes, and it is recommended for most real-world cases.