# I am using multiple shooting to learn the Lotka-Volterra equations. How can I predict results using the updated (trained) parameters?
using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms,
      DifferentialEquations, Plots
using DiffEqFlux: group_ranges
using Random
rng = Random.default_rng()
# Define the initial condition, the time span and the time grid at which data are saved.
u0 = Float32[1.0; 1.0]
tspan = (0.0f0, 10.0f0)
datasize = 70
tsteps = range(tspan[1], tspan[2]; length = datasize)
# Default Lotka-Volterra parameters (α, β, δ, γ), used when the problem carries no
# parameter vector (e.g. `ODEProblem(lotka_volterra, u0, tspan)`). A tuple of isbits
# values so re-running the script does not warn about redefining a constant.
const LV_DEFAULT_PARAMS = (1.5f0, 1.0f0, 3.0f0, 1.0f0)

# Pick the parameters: a caller-supplied vector wins; anything else (e.g. the
# "no parameters" placeholder) falls back to the defaults.
_lv_params(p::AbstractVector) = p
_lv_params(p) = LV_DEFAULT_PARAMS

"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka-Volterra system

    dx/dt = α*x - β*x*y
    dy/dt = -δ*y + γ*x*y

with state `u = (x, y)`. `p` is `[α, β, δ, γ]` when it is a vector; otherwise
`LV_DEFAULT_PARAMS` is used. Writes the derivatives into `du` and returns `nothing`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = _lv_params(p)
    # Explicit `*`: juxtaposition like `αx` would be read as one undefined identifier.
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
# Problem that generates the training data; its parameters are given explicitly so they
# are visible here (these are the values `lotka_volterra` otherwise defaults to).
p_true = Float32[1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p_true)

# Verify the ODE solution: sample it at `tsteps`; this array is the data the network is fit to.
# (The training animation object is created just before the callback.)
ode_data = Array(solve(prob, Tsit5(); saveat = tsteps))
# Define the neural network. The first entry is a plain function that cubes its input
# elementwise (it must be `->`; `→` is an undefined infix operator). The input and the
# output both have width 2, matching the 2-component state `u0`.
nn = Lux.Chain(x -> x .^ 3,
               Lux.Dense(2, 84, swish),
               Lux.Dense(84, 44, swish),
               Lux.Dense(44, 22, swish),
               Lux.Dense(22, 12, swish),
               Lux.Dense(12, 2))
p_init, st = Lux.setup(rng, nn)

# NeuralODE wrapper around the same network. It is not used by the multiple-shooting
# training below, which works on `prob_node`.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)

# Out-of-place ODE whose right-hand side is the network; `[1]` keeps the network output
# and drops the returned Lux state.
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init))
"""
    plot_multiple_shoot(plt, preds, group_size; state = 1)

Overlay the prediction of every multiple-shooting group on `plt`.

The groups are `group_ranges(datasize, group_size)` (using the global `datasize`).
`preds[i]` is plotted for group `i`: its row `state` against `tsteps[rg]` (global
`tsteps`). Mutates `plt` and also returns it.
"""
function plot_multiple_shoot(plt, preds, group_size; state::Integer = 1)
    for (i, rg) in enumerate(group_ranges(datasize, group_size))
        plot!(plt, tsteps[rg], preds[i][state, :];
              markershape = :circle, label = "Group $(i)")
    end
    return plt
end
# Animate training (the animation cannot be made on a CI server).
anim = Plots.Animation()
iter = 0
# Plot every `plot_every`-th callback invocation.
plot_every = 1

# Optimization callback: prints the loss `l`, counts iterations and, when `doplot` is set,
# draws the data together with the per-shoot predictions `preds` and records a frame.
# Returning `false` lets the optimizer continue.
callback = function (p, l, preds; doplot = true)
    display(l)
    global iter
    iter += 1
    if doplot && iter % plot_every == 0
        # Plot the original data (first state variable); ASCII quotes are required here.
        plt = scatter(tsteps, ode_data[1, :]; label = "Data")
        # Plot the different predictions for the individual shoots.
        plot_multiple_shoot(plt, preds, group_size)
        frame(anim, plt)
        display(plot(plt))
    end
    return false
end
# Define the parameters for multiple shooting. Both are passed to `multiple_shoot` below:
# `group_size` is the group size used to split the data, and `continuity_term` is the
# continuity weight (per the DiffEqFlux docs, it penalizes mismatch between adjacent groups).
group_size = 8
continuity_term = 200
"""
    loss_function(data, pred)

Return the sum of squared differences between `data` and `pred`.
"""
function loss_function(data, pred)
    residual = data - pred
    return sum(abs2, residual)
end
# Multiple-shooting loss for network parameters `p`. It fits `prob_node` to `ode_data` on
# `tsteps`, scores each group with `loss_function`, and uses the global `group_size` and
# `continuity_term` settings. Returns whatever `multiple_shoot` returns.
loss_multiple_shooting(p) =
    multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
                   continuity_term = continuity_term)
# Train the network parameters with Zygote gradients. The anonymous function needs `->`;
# `→` is an undefined infix operator.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_init))
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
gif(anim, "multiple_shooting.gif"; fps = 15)

# Predict with the trained (updated) parameters: `res_ms.u` is the optimized parameter
# vector. Re-solve the network ODE with those parameters from `u0`. To extrapolate, pass a
# longer `tspan` to `remake` and a matching `saveat` to `solve`.
# NOTE: multiple shooting fits each group separately, so a single solve over the whole span
# may drift from the per-group fits shown during training.
p_trained = res_ms.u
prob_trained = remake(prob_node; p = p_trained)
pred = Array(solve(prob_trained, Tsit5(); saveat = tsteps))

plt_pred = scatter(tsteps, ode_data[1, :]; label = "Data")
plot!(plt_pred, tsteps, pred[1, :]; label = "Prediction")
display(plt_pred)