Multiple shooting with multiple trajectories: how to properly implement this?

DiffEq has a tutorial on multiple shooting for neural ODEs. It works for a single trajectory. I have a dataset with several trajectories, each one a solution to a differential equation with a different initial condition. I have an implementation that runs, but I doubt it is efficient or the best way of implementing this. Can you tell me whether there are packages with tutorials that demonstrate how to do this properly, and/or show me how I can improve my code?

In general terms, I took this part of the tutorial:

function loss_multiple_shooting(p)
    ps = ComponentArray(p, pax)
    return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
        Tsit5(), group_size; continuity_term)
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)

and replaced it with the following:

function loss_multiple_shooting(p, p_axes, group_size, dataset, ode_problem)
    ps = ComponentArray(p, p_axes)
    return multiple_shoot(ps, dataset, time_steps, ode_problem, loss_function,
        time_stepper, group_size; continuity_term)
end

function solution_to_parameters(solution::Vector{Float32})::ModelParams
    reshaped = reshape(solution,(Nx,Nx+1))
    return (
        weight = reshaped[:,1:end-1], 
        bias = reshape(reshaped[:,end],(:,1))
    )
end

...

losses = []
for (group_size, epochs) in epochs_per_group_size
    for epoch = ProgressBar(1:epochs)
        loss = 0
        for batch in ProgressBar(train_loader)
            trajectory = batch[1]
            u0 = OffsetArrays.no_offset_view(trajectory[:, 1])

            ps = ComponentArray(parameters)
            p_data, p_axes = getdata(ps), getaxes(ps)
                        
            ode_problem = ODEProblem(
                (u, p, t) -> model(u, p, st)[1], 
                u0, 
                time_span, 
                ps
            )

            opt_func = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x, p_axes, group_size, trajectory, ode_problem), ad_type)
            opt_prob = Optimization.OptimizationProblem(opt_func, p_data)
            result_solve = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(), maxiters = solver_iterations; callback = callback);

            global parameters = solution_to_parameters(result_solve.u)

            loss += result_solve.objective
        end
        push!(losses, loss)
        display((epoch, loss))
        flush(stdout)
    end
end

I call ComponentArray a lot, which seems inefficient. Is there a better way of updating u0 for the ode_problem and the parameters for the neural network model?
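On the ComponentArray question, here is a minimal sketch of one way to avoid the manual reshape in solution_to_parameters, assuming the parameter layout stays fixed during training. The weight and bias values and the `updated` vector are made up for illustration; only `ComponentArray`, `getdata` and `getaxes` come from ComponentArrays.jl. The idea is to compute the axes once and re-wrap each new flat vector with them.

```julia
using ComponentArrays

# Hypothetical parameters of one dense layer, mirroring the weight/bias layout above.
weight = Float32[1 2; 3 4]
bias = reshape(Float32[5, 6], :, 1)
ps = ComponentArray(weight = weight, bias = bias)

# Compute the flat data and the axes once, before the training loop.
p_data, p_axes = getdata(ps), getaxes(ps)

# Stand-in for the flat vector an optimizer returns (result_solve.u in the code above).
updated = p_data .+ 1.0f0

# Re-wrap the flat vector with the stored axes; no manual reshape is needed.
new_ps = ComponentArray(updated, p_axes)
```

With this, `new_ps.weight` and `new_ps.bias` give back the matrix-shaped blocks, so the flat vector from one batch could be passed straight on as the starting point for the next.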

Hello and welcome to the community :wave:

Have you read the performance tips in the Julia manual? In particular, the first point:

Performance critical code should be inside a function

After that, try profiling your code, e.g., using @profview and @profview_allocs in VS Code, to learn what is taking time and to get a feeling for where you could improve your code.
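As a rough illustration of both points, the sketch below wraps a toy workload in a function and profiles it with the Profile standard library, which gives a text report outside VS Code. The function `sum_of_squares` is a hypothetical stand-in for a training loop, not anything from the code above.

```julia
using Profile

# Toy workload, kept inside a function so the compiler can specialize it.
function sum_of_squares(n)
    total = 0.0
    for i in 1:n
        total += Float64(i)^2
    end
    return total
end

# Run once first so that compilation time does not end up in the profile.
sum_of_squares(10)

Profile.clear()
@profile sum_of_squares(10^8)

# Print a call tree; the sample counts show where the time was spent.
Profile.print(maxdepth = 12)
```

In the printed tree, the lines with the largest sample counts are the ones worth looking at first.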

Hi Baggepinnen, thank you for the pointer. I wasn’t aware that I had to do so.

I placed it into a function like this:

function main_loop(model, parameters, epochs_per_group_size::Dict{Int, Int}, train_loader, solver_iterations::Int32)
    losses = Float32[]  # concretely typed, matching the Float32 loss accumulated below
    for (group_size, epochs) in epochs_per_group_size
        for epoch = ProgressBar(1:epochs)
            loss::Float32 = 0.0
            for batch in ProgressBar(train_loader)
                trajectory = batch[1]
                u0 = OffsetArrays.no_offset_view(trajectory[:, 1])

                ps = ComponentArray(parameters)
                p_data, p_axes = getdata(ps), getaxes(ps)

                ode_problem = ODEProblem(
                    (u, p, t) -> model(u, p, st)[1],
                    u0,
                    time_span, 
                    ps;
                    isoutofdomain = outofdomain
                )

                opt_func = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x, p_axes, group_size, trajectory, ode_problem), ad_type)
                opt_prob = Optimization.OptimizationProblem(opt_func, p_data)
                result_solve = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(); callback = callback, maxiters = solver_iterations, progress = true);

                parameters = solution_to_parameters(result_solve.u)

                loss += result_solve.objective
            end
            push!(losses, loss)
            display((epoch, loss))
            flush(stdout)
        end
    end
    return losses, parameters
end

Running @profview and @profview_allocs on it shows me some graphs that I do not completely understand. The graphs suggest that most of the time is spent in the solve function. In the meantime, I have found functions like remake and ncycle that seem to be alternative ways of doing part of the inner-loop work. If someone could tell me whether using them actually matters, that would help a lot. Otherwise, I'll just have to find out with some numerical tests.
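For reference, a minimal sketch of what remake does, assuming a toy linear ODE in place of the neural network right-hand side. The names `decay`, `base_problem` and `initial_conditions` are made up for illustration. The problem is built once, and each trajectory only swaps in its own initial condition. Whether this makes a measurable difference when solve dominates the runtime is something only a timing comparison can show.

```julia
using OrdinaryDiffEq

# Toy right-hand side standing in for the neural network (an assumption).
decay(u, p, t) = p .* u

time_span = (0.0f0, 1.0f0)
p = Float32[-1.0]

# Build the problem once, with a placeholder initial condition.
base_problem = ODEProblem(decay, Float32[1.0], time_span, p)

# Hypothetical initial conditions, one per trajectory.
initial_conditions = [Float32[1.0], Float32[2.0], Float32[3.0]]

final_states = map(initial_conditions) do u0
    # remake returns a new problem with u0 replaced; base_problem is left untouched.
    problem = remake(base_problem; u0 = u0)
    sol = solve(problem, Tsit5(); save_everystep = false)
    sol.u[end][1]
end
```

remake accepts `p = ...` in the same way, so updated parameters could be swapped in without calling ODEProblem again.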