Prediction with multiple shooting

I am using multiple shooting to learn the Lotka–Volterra equations. How can I predict results using the updated (trained) parameters?

using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms,
      DifferentialEquations, Plots
using DiffEqFlux: group_ranges

using Random
# RNG handed to `Lux.setup` to initialize the network parameters.
rng = Random.default_rng()

# Define initial conditions and time steps

u0 = Float32[1.0, 1.0]
tspan = (0.0f0, 10.0f0)
datasize = 70

# Evenly spaced save points shared by the data generation, the loss and the plots.
tsteps = range(tspan[1], tspan[2], length = datasize)
# In-place Lotka–Volterra right-hand side: x is the prey population, y the predator
# population, and p = [α, β, δ, γ] holds the rate parameters.
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

# True parameters used to generate the training data (Float32 to match u0).
p_true = Float32[1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p_true)

# Verify ODE solution

ode_data = Array(solve(prob, Tsit5(), saveat = tsteps))
anim = Plots.Animation()

# Define the Neural Network

# The first layer cubes the 2-element state element-wise; the last layer maps back to
# 2 outputs, one derivative per state component.
nn = Lux.Chain(x -> x .^ 3,
    Lux.Dense(2, 84, swish),
    Lux.Dense(84, 44, swish),
    Lux.Dense(44, 22, swish),
    Lux.Dense(22, 12, swish),
    Lux.Dense(12, 2))
p_init, st = Lux.setup(rng, nn)

# NOTE(review): `neuralode` is not used by the multiple-shooting code in this post.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)

# Neural ODE problem whose right-hand side is the network. A Lux model is called as
# `model(x, ps, st)` and returns `(output, state)`, hence the `[1]`.
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan,
    ComponentArray(p_init))

# Overlay the first state component of every shooting segment on `plt`.
# `preds[i]` is the prediction for the i-th range of `group_ranges(datasize, group_size)`;
# the time points are taken from the global `tsteps`. Returns `plt`.
function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(datasize, group_size)
    for (i, rg) in enumerate(ranges)
        # Unescaped `$(i)` so the label is interpolated ("Group 1", "Group 2", ...).
        plot!(plt, tsteps[rg], preds[i][1, :], markershape = :circle, label = "Group $(i)")
    end
    return plt
end

anim = Plots.Animation()

# Number of callback invocations so far (one per optimizer step).
iter = 0

# Optimization callback: prints the loss `l`. When `doplot` is true, every
# `plot_every`-th call it also plots the data against the per-segment predictions
# `preds`, appends the figure to `anim` and displays it.
# Returning `false` tells the optimizer not to stop.
callback = function (p, l, preds; doplot = true, plot_every = 1)
    display(l)
    global iter
    iter += 1
    if doplot && iter % plot_every == 0
        # plot the original data
        plt = scatter(tsteps, ode_data[1, :], label = "Data")

        # plot the different predictions for individual shoot
        plot_multiple_shoot(plt, preds, group_size)

        frame(anim, plt)
        display(plt)
    end
    return false
end

# Define parameters for Multiple Shooting

group_size = 8         # number of time points per shooting segment
continuity_term = 200  # weight of the continuity penalty between consecutive segments

# Sum of squared residuals between the observed `data` and the prediction `pred`.
loss_function(data, pred) = sum(abs2, data - pred)

# Multiple-shooting loss for network parameters `p`: fits `prob_node` to `ode_data` at
# `tsteps` using segments of `group_size` points, with continuity between segments
# weighted by `continuity_term`. Returns whatever `multiple_shoot` returns.
loss_multiple_shooting(p) = multiple_shoot(p, ode_data, tsteps, prob_node, loss_function,
    Tsit5(), group_size; continuity_term = continuity_term)

# Automatic-differentiation backend for the optimizer.
adtype = Optimization.AutoZygote()

# The optimization variable `x` is the network parameter vector; the second
# argument (hyper-parameters) is unused.
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_init))
res_ms = Optimization.solve(optprob, PolyOpt(), callback = callback)
gif(anim, "multiple_shooting.gif", fps = 15)

# Predict with the trained parameters: `res_ms.u` holds the optimized network
# parameters, so rebuild the neural ODE problem with them and solve it over `tsteps`.
prob_trained = remake(prob_node, p = res_ms.u)
pred = Array(solve(prob_trained, Tsit5(), saveat = tsteps))

A minor formal issue: if you enclose your source code with triple backticks, the code will not only be better displayed but more convenient to copy and paste.

For example:

```
x = 1:10
```

Here’s your code properly indented and formatted for Markdown with triple backticks. There could be some editing errors in there. I did not try to run the code.

using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms,
      DifferentialEquations, Plots
using DiffEqFlux: group_ranges

using Random
# RNG handed to `Lux.setup` to initialize the network parameters.
rng = Random.default_rng()

# Define initial conditions and time steps

u0 = Float32[1.0, 1.0]
tspan = (0.0f0, 10.0f0)
datasize = 70

# Evenly spaced save points shared by the data generation, the loss and the plots.
tsteps = range(tspan[1], tspan[2], length = datasize)

# In-place Lotka–Volterra right-hand side: x is the prey population, y the predator
# population, and p = [α, β, δ, γ] holds the rate parameters.
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

# True parameters used to generate the training data (Float32 to match u0).
p_true = Float32[1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p_true)

# Verify ODE solution

ode_data = Array(solve(prob, Tsit5(), saveat = tsteps))
anim = Plots.Animation()

# Define the Neural Network

# The first layer cubes the 2-element state element-wise; the last layer maps back to
# 2 outputs, one derivative per state component.
nn = Lux.Chain(x -> x .^ 3,
    Lux.Dense(2, 84, swish),
    Lux.Dense(84, 44, swish),
    Lux.Dense(44, 22, swish),
    Lux.Dense(22, 12, swish),
    Lux.Dense(12, 2))
p_init, st = Lux.setup(rng, nn)

# NOTE(review): `neuralode` is not used by the multiple-shooting code in this post.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)

# Neural ODE problem whose right-hand side is the network. A Lux model is called as
# `model(x, ps, st)` and returns `(output, state)`, hence the `[1]`.
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan,
    ComponentArray(p_init))

# Overlay the first state component of every shooting segment on `plt`.
# `preds[i]` is the prediction for the i-th range of `group_ranges(datasize, group_size)`;
# the time points are taken from the global `tsteps`. Returns `plt`.
function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(datasize, group_size)
    for (i, rg) in enumerate(ranges)
        # Unescaped `$(i)` so the label is interpolated ("Group 1", "Group 2", ...).
        plot!(plt, tsteps[rg], preds[i][1, :], markershape = :circle, label = "Group $(i)")
    end
    return plt
end

# Animate training, cannot make animation on CI server

# anim = Plots.Animation()

# Number of callback invocations so far (one per optimizer step).
iter = 0

# Optimization callback: prints the loss `l`. When `doplot` is true, every
# `plot_every`-th call it also plots the data against the per-segment predictions
# `preds`, appends the figure to the global `anim` and displays it.
# Returning `false` tells the optimizer not to stop.
callback = function (p, l, preds; doplot = true, plot_every = 1)
    display(l)
    global iter
    iter += 1
    if doplot && iter % plot_every == 0
        # plot the original data
        plt = scatter(tsteps, ode_data[1, :], label = "Data")

        # plot the different predictions for individual shoot
        plot_multiple_shoot(plt, preds, group_size)

        frame(anim, plt)
        display(plt)
    end
    return false
end

# Multiple-shooting settings: points per segment and continuity-penalty weight.
group_size, continuity_term = 8, 200

# Sum of squared residuals between the observed `data` and the prediction `pred`.
function loss_function(data, pred)
    residuals = data - pred
    return sum(abs2, residuals)
end

# Multiple-shooting loss for network parameters `p`: fits `prob_node` to `ode_data` at
# `tsteps` using segments of `group_size` points, with continuity between segments
# weighted by `continuity_term`. Returns whatever `multiple_shoot` returns.
function loss_multiple_shooting(p)
    solver = Tsit5()
    return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, solver,
        group_size; continuity_term = continuity_term)
end

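Just do the same as in the tutorials. `res_ms.u` holds the learned parameters, so rebuild the problem with them using `remake` and solve it. In this script that is the neural-network problem, `remake(prob_node, p = res_ms.u)`. Do not use `prob` here, because that is the ground-truth Lotka–Volterra problem. For example: `pred = Array(solve(remake(prob_node, p = res_ms.u), Tsit5(), saveat = tsteps))`. It's no different from the `remake` in the callback.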