Neural ODEs: Fitting a cosine curve

Hi everyone,

I am trying to experiment with NeuralODEs and wanted to fit a simple cosine curve as an implementation exercise. I tried the same experiment with the solvers in torchdiffeq and diffrax, but both gave a bad fit, and their runtimes were too long to be practical.

I modified the example code to fit a cosine, but the resulting fit is pretty bad, both with Adam and after refining with BFGS. Any suggestions on how to improve the fit or the model? Note that the runtime here is good; it is only the model fit that is bad.

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL, OptimizationOptimisers, Random, Plots

rng = Random.default_rng();
# u0 = Float32[2.0; 0.0]
datasize = 101;
tspan = (-4.0f0 * π, 4.0f0 * π);
tsteps = range(tspan[1], tspan[2]; length = datasize);

function trueODEfunc(du, u, p, t)
    du .= -sin.(t);
end;

prob_trueode = ODEProblem(trueODEfunc, [1.0], tspan);
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps));

plt = scatter(tsteps, ode_data[1, :]; label="true data")
display(plot(plt))

dudt2 = Chain(x -> x.^3, Dense(1, 50, tanh), Dense(50, 1));
p, st = Lux.setup(rng, dudt2);
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps);

function predict_neuralode(p)
    Array(prob_neuralode([0.0f0], p, st)[1])
end;

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end;

callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        plot!(plt, tsteps, pred[1, :]; label = "prediction")
        display(plot(plt))
    end
    return false
end;

pinit = ComponentArray(p);
callback(pinit, loss_neuralode(pinit)...; doplot = true)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote();

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype);
optprob = Optimization.OptimizationProblem(optf, pinit);

result_neuralode = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.05); callback = callback,
    maxiters = 300);

optprob2 = remake(optprob; u0 = result_neuralode.u);

result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm = 0.01);
    callback, allow_f_increases = false)

callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)

Thanks!

EDIT: See the next comment about the initial condition.

You need to use multiple shooting (Multiple Shooting · DiffEqFlux.jl). Single shooting for neural ODEs trying to learn periodic data (or data with similar structure) is almost always going to fail.
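To make the idea concrete: multiple shooting splits the time series into short overlapping segments, integrates each segment separately, and adds a penalty for mismatches at the segment boundaries, so no single integration has to track the whole oscillation. Below is a minimal, dependency-free sketch of such a loss. It assumes a scalar state, a hand-written RK4 integrator in place of OrdinaryDiffEq, and that each segment starts from the observed data point at its first index. The names `shoot_ranges`, `simulate_segment` and `multiple_shooting_loss` are hypothetical and are not DiffEqFlux's API; `continuity_term = 200.0` simply mirrors the value used later in the thread.

```julia
# Hypothetical helper: overlapping index ranges that share their boundary point
# (1:10, 10:19, 19:28, ...), i.e. the grouping used for multiple shooting.
shoot_ranges(n, k) = [i:min(n, i + k - 1) for i in 1:(k - 1):(n - 1)]

# One classical RK4 step for a scalar ODE du/dt = f(u, θ, t).
function rk4_step(f, u, θ, t, h)
    k1 = f(u, θ, t)
    k2 = f(u + h / 2 * k1, θ, t + h / 2)
    k3 = f(u + h / 2 * k2, θ, t + h / 2)
    k4 = f(u + h * k3, θ, t + h)
    return u + h / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Simulate one segment, starting from the observed data point at its first index.
function simulate_segment(f, θ, data, ts, rg)
    pred = zeros(length(rg))
    pred[1] = data[first(rg)]
    for j in 1:(length(rg) - 1)
        i = rg[j]
        pred[j + 1] = rk4_step(f, pred[j], θ, ts[i], ts[i + 1] - ts[i])
    end
    return pred
end

# Data misfit over all segments plus a penalty on the jump between the end of
# each segment and the start of the next one.
function multiple_shooting_loss(f, θ, data, ts, k; continuity_term = 200.0)
    ranges = shoot_ranges(length(data), k)
    preds = [simulate_segment(f, θ, data, ts, rg) for rg in ranges]
    misfit = sum(sum(abs2, data[rg] .- p) for (rg, p) in zip(ranges, preds))
    jumps = sum((abs2(preds[i][end] - preds[i + 1][1]) for i in 1:(length(preds) - 1)); init = 0.0)
    return misfit + continuity_term * jumps
end

ts = range(-π, π; length = 120)
data = cos.(ts)
f_true(u, θ, t) = -sin(t)    # the right-hand side that generated the data
f_decay(u, θ, t) = θ * u     # hypothetical wrong model, for comparison
loss_true = multiple_shooting_loss(f_true, nothing, data, ts, 10)
loss_decay = multiple_shooting_loss(f_decay, -0.5, data, ts, 10)
println((loss_true, loss_decay))
```

With the true right-hand side the loss is essentially zero (only integrator error remains), while the wrong model is penalized on every segment. In a real neural ODE, `f` would be the network and the optimizer would minimize this loss over its parameters.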

Or you could do staged training (Smoothed Collocation for Fast Two-Stage Training · DiffEqFlux.jl).
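A rough sketch of the cheap first stage of that idea: estimate du/dt directly from the data, then fit a right-hand-side model to those estimates by regression, with no ODE solves in the loop. The result would then serve as the initial guess for ordinary ODE-based training. This sketch assumes plain finite differences in place of the kernel smoothing used in the DiffEqFlux collocation tutorial, and a hypothetical two-term basis sin(t), cos(t) in place of the neural network.

```julia
using LinearAlgebra

ts = collect(range(-π, π; length = 120))
u = cos.(ts)              # observed data
h = ts[2] - ts[1]

# Estimate du/dt from the data: central differences inside, one-sided at the ends.
du = similar(u)
du[2:end-1] .= (u[3:end] .- u[1:end-2]) ./ (2h)
du[1] = (u[2] - u[1]) / h
du[end] = (u[end] - u[end-1]) / h

# Fit du/dt ≈ a*sin(t) + b*cos(t) by linear least squares: no ODE solves needed.
X = hcat(sin.(ts), cos.(ts))
a, b = X \ du
println((a, b))
```

Since the data come from du/dt = -sin(t), the recovered coefficients land close to (-1, 0). The point of the design is that each fitting step is a cheap regression rather than a full ODE solve.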


Why is the initial condition 0 here? It should be cos(0) = 1, shouldn’t it?
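A small sketch of taking the initial condition from the data instead of hard-coding it. It uses the same Float32 time grid as the original post and, as an assumption for self-containedness, `cos.(tsteps)` in place of the `solve` call that generated `ode_data`.

```julia
# Same grid as the original post: 101 Float32 points on [-4π, 4π].
datasize = 101
tspan = (-4.0f0 * π, 4.0f0 * π)
tsteps = range(tspan[1], tspan[2]; length = datasize)

# du/dt = -sin(t) with u(-4π) = 1 has the solution u(t) = cos(t),
# so the data start at 1, not 0.
data = cos.(tsteps)
u0 = Float32[data[1]]   # initial condition consistent with the data
println(u0)
```

In the original script, the equivalent change would be calling `prob_neuralode(Float32[ode_data[1, 1]], p, st)` inside `predict_neuralode` instead of `prob_neuralode([0.0f0], p, st)`.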


Chris Rackauckas gave a great (long) talk about related issues, and in Part 2 (starting around the 2h08m mark) he got into some techniques that you’ll need when fitting to oscillatory solutions.

He discusses multiple shooting and the collocation method @avikpal mentioned. But probably the easiest thing to try is to get an initial guess by fitting to just one oscillation, then use that as the initial guess to optimize over more oscillations (“growing the time interval”). It may also be relevant to use the prediction error method, where you modify your ODE to penalize errors directly.
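A minimal sketch of “growing the time interval”, under loudly stated assumptions: a one-parameter toy model u(t) = cos(ω t) (the exact solution of du/dt = -ω sin(ω t) with u(0) = 1) stands in for the neural ODE, and plain gradient descent with a backtracking line search stands in for Adam/BFGS. The function names and the window schedule (13, 26, 51, 101 points) are hypothetical choices for illustration. Each stage is warm-started from the previous stage's result.

```julia
# Toy stand-in for the neural ODE: u(t) = cos(ω t) solves du/dt = -ω sin(ω t), u(0) = 1.
model(ω, t) = cos(ω * t)

# Sum-of-squares loss over the first n data points and its derivative in ω.
function loss_and_grad(ω, ts, data, n)
    L, g = 0.0, 0.0
    for i in 1:n
        r = model(ω, ts[i]) - data[i]
        L += r^2
        g += 2r * (-ts[i] * sin(ω * ts[i]))
    end
    return L, g
end

# Gradient descent with a capped step and Armijo backtracking,
# so every accepted step lowers the loss on the current window.
function descend(ω, ts, data, n; iters = 200, maxstep = 0.1)
    for _ in 1:iters
        L, g = loss_and_grad(ω, ts, data, n)
        step = clamp(-g, -maxstep, maxstep)
        while abs(step) > 1e-12 && first(loss_and_grad(ω + step, ts, data, n)) > L + 0.5 * g * step
            step /= 2
        end
        abs(step) > 1e-12 || break    # no acceptable step left: treat as converged
        ω += step
    end
    return ω
end

# Growing the time interval: fit a short prefix first, then warm-start on longer ones.
function grow_fit(ω, ts, data, window_ends)
    for n in window_ends
        ω = descend(ω, ts, data, n)
    end
    return ω
end

ts = collect(range(0, 4π; length = 101))
data = cos.(ts)                                      # "true" data, i.e. ω = 1
ω_fit = grow_fit(0.7, ts, data, (13, 26, 51, 101))   # start from a poor guess ω = 0.7
println(ω_fit)
```

On the shortest window the loss has a single minimum near the starting guess, so the first stage lands close to ω = 1 and the longer windows only refine it. For comparison you can run `descend(0.7, ts, data, 101)`, which fits the full window in one go from the same poor guess.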


Thanks for the pointers @moble @avikpal @tomaklutfu! Appreciate the help.

The fit is definitely better with multiple shooting and with a single (half) wave instead of multiple periods, but it still gets stuck in local minima. I’ve tried different optimizers, but the optimization mostly stalls around this same local minimum.

Code
using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, OrdinaryDiffEq, Plots
using DiffEqFlux: group_ranges
using OptimizationOptimisers: Adam, AdaDelta, AdaGrad
using OptimizationOptimJL: BFGS

using Random
rng = Random.default_rng()

# Define initial conditions and time steps
datasize = 120;
u0 = [Float32(-1.0)];
tspan = map(Float32, (-π, π));
tsteps = range(tspan[1], tspan[2]; length = datasize);

# Get the data
function trueODEfunc(du, u, p, t)
    du .= -sin.(t);
end;
prob_trueode = ODEProblem(trueODEfunc, u0, tspan);
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps));
plt=plot(tsteps, ode_data[1, :]; label="data")
display(plt)
# Define the Neural Network
nn = Chain(Dense(1, 12, relu), Dense(12, 12, relu), Dense(12, 12, relu), Dense(12, 6, relu), Dense(6, 1));
p_init, st = Lux.setup(rng, nn);

neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps);
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init));

function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(datasize, group_size)

    for (i, rg) in enumerate(ranges)
        plot!(plt, tsteps[rg], preds[i][1, :]; markershape = :circle, label = "Group $(i)")
    end
end;

anim = Plots.Animation();
iter = 0;
callback = function (p, l, preds; doplot = true)
    display(l)
    global iter
    iter += 1
    if doplot && iter % 1 == 0
        # plot the original data
        plt = scatter(tsteps, ode_data[1, :]; label = "Data")

        # plot the different predictions for individual shoot
        plot_multiple_shoot(plt, preds, group_size)

        frame(anim)
        display(plot(plt))
    end
    return false
end;

# Define parameters for Multiple Shooting
group_size = 10;
continuity_term = 200;

function loss_function(data, pred)
    return sum(abs2, data - pred)
end;

ps = ComponentArray(p_init);
pd, pax = getdata(ps), getaxes(ps);

function loss_multiple_shooting(p)
    ps = ComponentArray(p, pax)
    return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(),
        group_size; continuity_term)
end;

adtype = Optimization.AutoZygote();
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype);
optprob = Optimization.OptimizationProblem(optf, pd);
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters=300);
gif(anim, "multiple_shooting.gif"; fps = 15);