The DiffEqFlux docs have a tutorial on multiple shooting for neural ODEs, but it only handles a single trajectory. My dataset has several trajectories, each a solution of the same differential equation from a different initial condition. My implementation runs, but I doubt it is efficient or the best approach. Are there packages with tutorials that show how to do this properly, and/or can you suggest how to improve my code?

In short, I took this part of the tutorial:

```
# Multiple-shooting loss from the tutorial: wrap the flat parameter vector `p` with
# the global axes `pax` to get a structured view, then evaluate DiffEqFlux's
# `multiple_shoot` on the global data `ode_data` sampled at `tsteps`, solving each
# segment of `prob_node` with Tsit5.
function loss_multiple_shooting(p)
    structured = ComponentArray(p, pax)
    return multiple_shoot(
        structured, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term = continuity_term,
    )
end
# Optimize the initial parameters `pd` using Zygote for gradients. The objective
# ignores the second (hyperparameter) argument that Optimization.jl passes in.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss_multiple_shooting(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback)
```

and replaced it with

```
# Multiple-shooting loss for one trajectory.
#   p           flat parameter vector (what the optimizer updates)
#   p_axes      ComponentArrays axes used to view `p` as structured parameters
#   group_size  number of time points per shooting segment
#   dataset     observed trajectory (presumably states × time — TODO confirm)
#   ode_problem neural-ODE problem handed to `multiple_shoot`
# Reads the globals `time_steps`, `loss_function`, `time_stepper`, `continuity_term`.
function loss_multiple_shooting(p, p_axes, group_size, dataset, ode_problem)
    structured = ComponentArray(p, p_axes)
    return multiple_shoot(
        structured, dataset, time_steps, ode_problem, loss_function,
        time_stepper, group_size; continuity_term = continuity_term,
    )
end
"""
    solution_to_parameters(solution::AbstractVector)::ModelParams

Rebuild the `(weight = ..., bias = ...)` named tuple from a flat parameter vector,
e.g. `result_solve.u` returned by `Optimization.solve`.

`solution` is read column-major as an `Nx × (Nx + 1)` matrix, where `Nx` is a global
defined elsewhere. The first `Nx` columns become `weight` (`Nx × Nx`), and the last
column becomes `bias` as an `Nx × 1` matrix. Both fields are copies, not views.
This is presumably the inverse of `getdata(ComponentArray(parameters))` for a single
dense layer — TODO confirm against how `parameters` is built.

Throws `DimensionMismatch` (from `reshape`) if `length(solution) != Nx * (Nx + 1)`.
"""
function solution_to_parameters(solution::AbstractVector)::ModelParams
    # Accept any AbstractVector: the optimizer's result need not be a Vector{Float32}
    # (e.g. when parameters were initialised in Float64).
    reshaped = reshape(solution, (Nx, Nx + 1))
    return (
        weight = reshaped[:, 1:end-1],
        # `end:end` keeps the column as an Nx × 1 matrix (same as reshape(…, (:, 1))).
        bias = reshaped[:, end:end],
    )
end
...
# Training loop. For each (group_size, epochs) stage, run `epochs` passes over
# `train_loader`, fitting the parameters to each trajectory with multiple shooting.
# `losses` gets one entry per epoch: the sum, over batches, of the final objective
# reported by `Optimization.solve` for that batch.
losses = Float64[]
for (group_size, epochs) in epochs_per_group_size
    # Flatten the parameters once per stage instead of once per batch. After each
    # batch, the optimizer's flat result becomes the next starting point directly,
    # so the NamedTuple -> ComponentArray round trip is no longer needed.
    ps = ComponentArray(parameters)
    p_data, p_axes = getdata(ps), getaxes(ps)

    # Neural-ODE right-hand side, created once and shared by every batch's problem.
    node_rhs = (u, p, t) -> model(u, p, st)[1]

    # The objective depends only on `group_size` and `p_axes`. The per-batch
    # trajectory and ODE problem come in through the OptimizationProblem's `p`
    # argument, so this function object is built once per stage.
    opt_func = Optimization.OptimizationFunction(
        (x, batch_data) -> loss_multiple_shooting(
            x, p_axes, group_size, batch_data.trajectory, batch_data.problem,
        ),
        ad_type,
    )

    for epoch in ProgressBar(1:epochs)
        # Float accumulator so `loss` keeps one type as objectives are added.
        loss = 0.0
        for batch in ProgressBar(train_loader)
            trajectory = batch[1]
            u0 = OffsetArrays.no_offset_view(trajectory[:, 1])
            # `ComponentArray(p_data, p_axes)` only wraps the flat vector; it does
            # not re-flatten a NamedTuple.
            # NOTE(review): DiffEqFlux's `multiple_shoot` appears to `remake` the
            # problem with each segment's u0 taken from the data and with `p` set to
            # the current parameters. If so, this problem could be built once per
            # stage — verify against the installed DiffEqFlux version.
            ode_problem = ODEProblem(node_rhs, u0, time_span, ComponentArray(p_data, p_axes))
            opt_prob = Optimization.OptimizationProblem(
                opt_func, p_data, (trajectory = trajectory, problem = ode_problem),
            )
            # NOTE(review): a separate `solve` per batch restarts Adam's moment
            # estimates every batch. Optimization.jl can also iterate over a data
            # loader within one `solve` call (the API depends on the version).
            # Consider that approach to keep optimizer state across batches.
            result_solve = Optimization.solve(
                opt_prob, OptimizationOptimisers.Adam();
                maxiters = solver_iterations, callback = callback,
            )
            p_data = result_solve.u
            # Keep the global structured parameters in sync for code outside the loop.
            global parameters = solution_to_parameters(result_solve.u)
            loss += result_solve.objective
        end
        push!(losses, loss)
        display((epoch, loss))
        flush(stdout)
    end
end
```

I call `ComponentArray` a lot, which seems inefficient. Is there a better way to update `u0` for `ode_problem` and the parameters of the neural network model?