using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
# RNG used to initialise the network parameters in `Lux.setup` further down.
rng = Random.default_rng()

# Initial populations (x0, y0) for the true system.
u0 = ones(Float32, 2)

# Sample the trajectory at `datasize` equally spaced times across `tspan`.
datasize = 80
tspan = (0.0f0, 10.0f0)
tsteps = range(first(tspan), last(tspan); length = datasize)

# Ground-truth Lotka-Volterra rates, ordered (α, β, δ, γ).
# NOTE: `p` is rebound to the neural-network parameters by `Lux.setup` below.
p = Float32.([1.5, 1.0, 3.0, 1.0])
"""
    lotka_volterra(du, u, p, t)

In-place Lotka-Volterra right-hand side, suitable for `ODEProblem`.

`u = [x, y]` holds the two populations and `p = [α, β, δ, γ]` the rate
constants. Writes `α*x - β*x*y` into `du[1]` and `-δ*y + γ*x*y` into `du[2]`.
`t` is unused because the system is autonomous. Returns `nothing`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    # Multiplication must be explicit: `αx` or `βxy` would parse as single
    # (undefined) identifiers, not as products of the unpacked values.
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
# Ground-truth problem, solved and sampled at `tsteps`. The solution is
# converted to a plain `Array` (one row per state, one column per saved time)
# so the loss can broadcast it directly against the neural-ODE prediction.
prob_trueode = ODEProblem(lotka_volterra, u0, tspan, p)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
# Plot both populations: transposing makes each state one series (column),
# matching the two labels "x" and "y".
plot(tsteps, permutedims(ode_data), label = ["x" "y"], xlabel = "Time",
     ylabel = "Population", title = "Lotka-Volterra", lw = 3)
# Network for the unknown right-hand side: cube the state elementwise, then a
# 2 -> 50 -> 2 dense stack with a tanh hidden layer.
dudt2 = Lux.Chain(x -> x .^ 3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
# Initialise network parameters and layer state.
# NOTE: this rebinds the global `p` (previously the true Lotka-Volterra rates)
# to the network parameters. `prob_trueode` was already built with the true
# rates above, so this is safe here, but the shadowing is easy to trip over.
p, st = Lux.setup(rng, dudt2)
# Neural ODE whose right-hand side is the Lux network `dudt2`; its solution is
# saved at the same `tsteps` as the training data.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

"""
    predict_neuralode(p)

Integrate the neural ODE from `u0` using network parameters `p` and the global
layer state `st`, and return the saved trajectory as a plain `Array`.
"""
function predict_neuralode(p)
    # The model call returns a tuple whose first element is the solution.
    solution = first(prob_neuralode(u0, p, st))
    return Array(solution)
end
"""
    loss_neuralode(p)

Sum of squared errors between the training data `ode_data` and the neural-ODE
prediction for parameters `p`.

Returns the tuple `(loss, prediction)`, so the prediction is also available to
the optimizer callback.
"""
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    residual = ode_data .- prediction
    return sum(abs2, residual), prediction
end
# Optimizer callback: print the current loss `l` and, optionally, scatter the
# first state of the prediction `pred` against the data. It always returns
# `false`, which lets optimization continue (Optimization.jl callback
# convention).
# Do not plot by default for the documentation.
# Users should change doplot=true to see the plots callbacks.
callback = function (p, l, pred; doplot = false)
    println(l)
    # Plot current prediction against data.
    if doplot
        plt = scatter(tsteps, ode_data[1, :], label = "data")
        scatter!(plt, tsteps, pred[1, :], label = "prediction")
        display(plt)
    end
    return false
end
# Wrap the Lux network parameters in a ComponentArray for the optimizers.
pinit = Lux.ComponentArray(p)
# `loss_neuralode` returns `(loss, pred)`; splat it into the callback's `l` and
# `pred` arguments to show the untrained fit.
callback(pinit, loss_neuralode(pinit)...; doplot = true)
# Use Optimization.jl to solve the problem: first an ADAM pass
# (step size 0.05, at most 500 iterations) with Zygote gradients.
# NOTE(review): `ADAM` must be in scope from the packages loaded at the top of
# the file; newer package versions may require a different optimizer name or
# package — confirm for your environment.
adtype = Optimization.AutoZygote()
# The second closure argument (optimization hyperparameters) is unused.
optf = Optimization.OptimizationFunction((x, _) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
                                      ADAM(0.05),
                                      callback = callback,
                                      maxiters = 500)
# Refine the ADAM result with BFGS, restarting from the parameters ADAM found.
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(optprob2,
                                       Optim.BFGS(initial_stepnorm = 0.01),
                                       callback = callback,
                                       allow_f_increases = false)
# Show the final fit; `loss_neuralode` returns `(loss, pred)`, splatted into
# the callback's `l` and `pred` arguments.
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)