Hi everyone,
I am experimenting with Neural ODEs and wanted to fit a simple cosine curve as an implementation exercise. I tried the same experiment with the solvers in torchdiffeq and diffrax, but both gave a bad fit, and their runtimes were also too long to be practical.
I modified the example code to fit a cosine, but the resulting fit is still pretty bad, both with Adam and after refining with BFGS. Any suggestions on how to improve the fit or the model? Note that the runtime here is good; only the fit is bad.
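For reference, the target curve is generated by integrating a simple non-autonomous ODE. With the initial value fixed at the left end of the interval, its solution is exactly the cosine:

$$
\frac{du}{dt} = -\sin t,\quad u(-4\pi) = 1
\;\Longrightarrow\;
u(t) = 1 + \int_{-4\pi}^{t} (-\sin s)\,ds = 1 + \cos t - \cos(-4\pi) = \cos t,
\qquad t \in [-4\pi, 4\pi].
$$

So the data starts at $u(-4\pi) = 1$, which is the value any fitted model's initial condition has to match.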
```julia
using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL, OptimizationOptimisers, Random, Plots
rng = Random.default_rng();

# Training data: 101 samples of the true solution on [-4π, 4π]
datasize = 101;
tspan = (-4.0f0*π, 4.0f0*π);
tsteps = range(tspan[1], tspan[2]; length = datasize);

# True dynamics du/dt = -sin(t) with u(-4π) = 1, whose solution is u(t) = cos(t)
function trueODEfunc(du, u, p, t)
    du .= -sin.(t);
end;
prob_trueode = ODEProblem(trueODEfunc, [1.0], tspan);
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps));
plt = scatter(tsteps, ode_data[1, :]; label = "true data")
display(plot(plt))

# Initial condition for the neural ODE: the first data point, cos(-4π) = 1
u0 = Float32.(ode_data[:, 1]);

# Neural network for the right-hand side du/dt = NN(u)
dudt2 = Chain(x -> x.^3, Dense(1, 50, tanh), Dense(50, 1));
p, st = Lux.setup(rng, dudt2);
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps);

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end;

# Sum of squared errors between data and prediction
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end;

callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        plot!(plt, tsteps, pred[1, :]; label = "prediction")
        display(plot(plt))
    end
    return false
end;

pinit = ComponentArray(p);
callback(pinit, loss_neuralode(pinit)...; doplot = true)

# use Optimization.jl to solve the problem: Adam first, then refine with BFGS
adtype = Optimization.AutoZygote();
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype);
optprob = Optimization.OptimizationProblem(optf, pinit);
result_neuralode = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.05); callback = callback,
    maxiters = 300);
optprob2 = remake(optprob; u0 = result_neuralode.u);
result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm = 0.01);
    callback, allow_f_increases = false)
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)
```
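A quick sanity check of the training data, written as a standalone sketch rather than part of the script above: re-solving the same true ODE should reproduce `cos.(t)` up to roughly the solver tolerance, and the first sample should be 1.0 (= cos(-4π)), the value the neural ODE's initial condition has to match. The Float64 time span and the tight `abstol`/`reltol` values are assumptions chosen for this check only.

```julia
using OrdinaryDiffEq

# Same true dynamics as in the script: du/dt = -sin(t), u(-4π) = 1.
# Assumption for this check: Float64 times and tight tolerances.
trueODEfunc(du, u, p, t) = (du .= -sin(t); nothing)

tspan = (-4π, 4π)
tsteps = range(tspan[1], tspan[2]; length = 101)
prob = ODEProblem(trueODEfunc, [1.0], tspan)
sol = solve(prob, Tsit5(); saveat = tsteps, abstol = 1e-10, reltol = 1e-10)
data = Array(sol)

# Compare the numerical solution with the analytic solution cos(t)
maxerr = maximum(abs.(data[1, :] .- cos.(tsteps)))
println("first sample: ", data[1, 1])
println("max |u(t) - cos(t)|: ", maxerr)
```

If the maximum error is small, the data really is the cosine, and any remaining mismatch in the fit comes from the model or training rather than from the data generation.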
Thanks!