using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL
using Random, Plots, ComponentArrays
# Random number generator handle (Julia's default RNG).
rng = Random.default_rng()

# Two-component Float32 initial state.
u0 = Float32[2.0, 0.0]

# Number of sample points and the (start, stop) time window.
datasize = 30
tspan = (0.0f0, 1.5f0)

# `datasize` evenly spaced sample times covering `tspan` end to end.
tsteps = range(first(tspan), last(tspan); length = datasize)
"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference ("true") dynamics.

Writes `((u .^ 3)' * true_A)'` into `du`, where `true_A` is the fixed 2×2 matrix
defined in the body. `p` and `t` are accepted but not used.

Returns `du`.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    # The ASCII apostrophe is the adjoint operator: row vector times matrix,
    # transposed back into a column before being broadcast-assigned into `du`.
    du .= ((u .^ 3)' * true_A)'
    return du
end
# Integrate the reference dynamics from `u0` over `tspan`, saving at `tsteps`,
# and keep the result as a plain array to serve as training data.
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Network: elementwise cube of the input, then Dense(2, 50, tanh), then Dense(50, 2).
# The first layer must be an anonymous function written with the ASCII arrow `->`.
dudt2 = Lux.Chain(x -> x .^ 3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))

# Initial parameters `p` and layer state `st` for the network.
p, st = Lux.setup(rng, dudt2)

# Neural ODE over the same time span, solver and save points as the reference problem.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
"""
    predict_neuralode(p)

Call `prob_neuralode` with the initial state `u0`, parameters `p` and state `st`,
and return the first element of its result converted to an `Array`.
"""
function predict_neuralode(p)
    output = prob_neuralode(u0, p, st)
    return Array(output[1])
end
"""
    loss_neuralode(p)

Sum of squared differences between `ode_data` and the prediction obtained from
`predict_neuralode(p)`.

Returns the tuple `(loss, pred)` so the prediction can be reused by the caller.
"""
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    residual = ode_data .- prediction
    return sum(abs2, residual), prediction
end
# Do not plot by default for the documentation.
# Users should change doplot=true to see the plots callbacks.
#
# callback(p, l, pred; doplot = false)
#   Prints the loss `l`. When `doplot` is true, it also displays a scatter plot of
#   the first row of `ode_data` and of `pred` against `tsteps`. `p` is not used.
#   Always returns `false`.
#   NOTE(review): presumably `false` tells the optimizer to keep iterating —
#   confirm against the Optimization.jl callback convention.
callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1, :], label = "data")
        scatter!(plt, tsteps, pred[1, :], label = "prediction")
        display(plot(plt))
    end
    return false
end
# Wrap the parameters returned by `Lux.setup` in a ComponentArray so they can be
# handed to the optimizer as the initial point.
# NOTE(review): if the installed Lux no longer exposes `ComponentArray`, the same
# constructor is available unqualified from ComponentArrays — confirm.
pinit = Lux.ComponentArray(p)

# Print the loss before training and show the initial fit. `loss_neuralode` returns
# `(loss, pred)`, which the ASCII splat `...` expands into the `l` and `pred` arguments.
callback(pinit, loss_neuralode(pinit)...; doplot = true)
# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

# The objective ignores the second (hyper-parameter) argument and forwards only `x`.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

# Stage 1: 300 iterations of ADAM with step size 0.05.
# NOTE(review): `ADAM` must be exported by one of the loaded packages; newer releases
# may spell it `Adam` and provide it from a separate optimisers package — confirm.
result_neuralode = Optimization.solve(optprob,
                                      ADAM(0.05),
                                      callback = callback,
                                      maxiters = 300)

# Stage 2: restart from the stage-1 result and refine with BFGS.
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(optprob2,
                                       Optim.BFGS(initial_stepnorm = 0.01),
                                       callback = callback,
                                       allow_f_increases = false)

# Print the final loss and show the fitted prediction against the data.
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)