Optimizing the computation of Jacobians for a multiple shooting implementation

Hi, I’ve been working with multiple shooting for trajectory optimization. I am currently using the sparse backend from DifferentiationInterface.jl and using DifferentialEquations.jl to perform the integration. A snippet of my current code is inserted below:

function compute_dynamics_arcs!(
    y::AbstractVector{T}, u::AbstractVector{T}, c
) where {T<:Real}
    dyn = c.dyn
    times = c.times
    integrator = c.integrator
    state_nodes = c.state_nodes
    state_size = c.state_size
    state_elements = (state_nodes - 1) * state_size

    column_start = state_elements + 1

    for i in 1:length(dyn.components)
        if :times in fieldnames(typeof(dyn.components[i]))
            control_size = length(dyn.components[i].control[1])
            control_nodes = length(dyn.components[i].control)
            control_elements = control_nodes * control_size

            # `k` indexes into `u`; it is named differently from the component index `i`
            @views controls = [
                SA[u[k:(k + control_size - 1)]...] for
                k in column_start:control_size:(column_start + control_elements - 1)
            ]
            dyn = @set dyn.components[i].control = controls

            column_start += control_elements
        end
    end

    # Using solve vs. step!/reinit! for AD tends to be faster as DifferentialEquations
    # provides a lot of optimizations for the solve function
    prob = ODEProblem(ode_dynamics, SA[u[1:state_size]...], (times[1], times[end]), dyn)

    j = 1

    for i in 1:state_size:(state_elements)
        @views state = SA[u[i:(i + state_size - 1)]...]

        cb = dynamics_callback(dyn, state, (times[j], times[j + 1]))
        tstops = dynamics_tstops(dyn, (times[j], times[j + 1]))

        sol = solve(
            remake(prob; u0=state, tspan=(times[j], times[j + 1])),
            integrator;
            callback=cb,
            abstol=1e-8,
            reltol=1e-8,
            tstops=tstops,
            d_discontinuities=tstops,
            save_everystep=false,
        )
        y[i:(i + state_size - 1)] = sol.u[end]

        j += 1
    end

    return nothing
end

BenchmarkTools.Trial: 1507 samples with 1 evaluation per sample.
 Range (min … max):  2.105 ms … 42.442 ms  ┊ GC (min … max):  0.00% … 93.07%
 Time  (median):     2.282 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   3.315 ms ±  2.860 ms  ┊ GC (mean ± σ):  21.10% ± 19.65%

  2.11 ms      Histogram: log(frequency) by time     12.7 ms <

 Memory estimate: 4.36 MiB, allocs estimate: 15871.

I am having trouble reducing the allocations and the initialization overhead to speed this up. Profiling shows that roughly half of the run time is spent initializing the DifferentialEquations solver. I also tried the step!/reinit! approach, which seemed like the natural fix, but performance was worse. My guess is that DifferentialEquations has optimized code paths when it detects differentiation through solve calls. Is there any way to keep the underlying caches between repeated solve calls?
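For reference, the kind of cache reuse being asked about would look roughly like the sketch below. It assumes a toy scalar ODE (u' = -u) in place of the real dynamics; `decay` and `propagate_arcs` are hypothetical names, and whether this pattern beats repeated solve calls under AD is exactly the open question.

```julia
using OrdinaryDiffEq

# Toy scalar dynamics u' = -u (a stand-in, not the real model).
decay(u, p, t) = -u

# Propagate every node over its own arc while reusing a single integrator.
function propagate_arcs(nodes::AbstractVector{T}, times) where {T<:Real}
    prob = ODEProblem(decay, nodes[1], (times[1], times[end]))
    # The integrator is built inside the function so that its caches match
    # eltype(nodes), which matters when ForwardDiff passes Dual numbers.
    integ = init(prob, Tsit5(); abstol=1e-10, reltol=1e-10, save_everystep=false)
    out = similar(nodes)
    for j in eachindex(nodes)
        # Reset state and time span, keeping the allocated caches.
        reinit!(integ, nodes[j]; t0=times[j], tf=times[j + 1])
        solve!(integ)  # run the reused integrator up to tf
        out[j] = integ.u
    end
    return out
end

times = collect(range(0.0, 1.0; length=5))
nodes = [1.0, 2.0, 3.0, 4.0]
arc_ends = propagate_arcs(nodes, times)
```

Each entry of `arc_ends` should be close to the corresponding node multiplied by exp(-0.25), since every arc spans 0.25 time units.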

Any thoughts on how this code may be improved? Thanks!

Hi, could you provide the full code of your example, including imports, function definitions, and whatever you’re actually measuring?

A screenshot of the profiling flame graph would also help

Thanks. Unfortunately, I can’t provide the full code for the example, as it would be far too long. However, the timings come from calling the jacobian! function from DifferentiationInterface:

@benchmark jacobian!(
    compute_dynamics_arcs!,
    stm_buffers[n][:y],
    stm_buffers[n][:jac_buffer],
    stm_buffers[n][:jac_prep],
    stm_buffers[n][:backend],
    x,
    stm_buffers[n][:context],
)

Here is a flamegraph generated from calling the jacobian! function 1000 times:

It is clear from this that approximately 40% of the time is spent in the init phase. Because of the discretization used with multiple shooting, it is difficult to reduce the number of separate integrations required.
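To make the structure concrete: in multiple shooting each arc is integrated independently from its own node state, so the number of integrations equals the number of arcs. A dependency-free sketch, assuming a toy ODE u' = -u and a hand-written fixed-step RK4 (`rk4`, `f`, and the node values are illustrative only, not the code being benchmarked):

```julia
# Fixed-step RK4, used here only to keep the sketch free of dependencies.
function rk4(f, u, t0, t1; steps=20)
    h = (t1 - t0) / steps
    t = t0
    for _ in 1:steps
        k1 = f(u, t)
        k2 = f(u + h / 2 * k1, t + h / 2)
        k3 = f(u + h / 2 * k2, t + h / 2)
        k4 = f(u + h * k3, t + h)
        u += h / 6 * (k1 + 2k2 + 2k3 + k4)
        t += h
    end
    return u
end

f(u, t) = -u
times = range(0.0, 1.0; length=5)
nodes = [exp(-t) for t in times]  # node states placed on the exact solution

# One independent integration per arc; each defect compares the end of an
# arc with the state stored at the next node.
defects = [
    rk4(f, nodes[j], times[j], times[j + 1]) - nodes[j + 1] for
    j in 1:(length(times) - 1)
]
max_defect = maximum(abs, defects)
```

Because the nodes here lie on a true trajectory, `max_defect` is close to zero. Since arc j depends only on node j, each arc contributes its own independent block to the Jacobian.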

Without the full code it will be very difficult to help you

Can you at least show ode_dynamics? Perhaps share the full code as a Pluto notebook?

The ode_dynamics function is a passthrough that iteratively calls all the separate parts of the dynamics (think state and control dynamics for multiple shooting). I am fairly confident that nothing in this part of the code is causing the problem.

I am primarily wondering if there is a way to avoid the repeated cost of __init calls when solve is called in the context of using AD on the solver, as using the reinit! function tended to be slower.

I’ll get back to you soon with a smaller example

That might be more of a question for BoundaryValueDiffEq.jl developers then

I looked at the source for BoundaryValueDiffEq.jl and they seem to use the reinit!/step! method just fine. I will retry this approach.

Thank you for your help.

FYI, here is the smaller testing example I constructed:

using DifferentiationInterface
using SparseArrays
using ADTypes: KnownJacobianSparsityDetector
using SparseMatrixColorings
using BenchmarkTools
using OrdinaryDiffEq
using StaticArrays

import ForwardDiff

function dynamics_example(
    u::SVector{N}, p, t::Real
) where {N<:Any}
    @assert N == 6 "Dynamics model requires 6D state vector"

    acc = -u[1:3] / (u[1]^2 + u[2]^2 + u[3]^2)^(3 / 2)

    return SVector{N}(u[4], u[5], u[6], acc[1], acc[2], acc[3])
end

function compute_dynamics_arcs_solve!(
    y::AbstractVector{T}, u::AbstractVector{T}, c
) where {T<:Real}
    times = c.times
    integrator = c.integrator
    state_nodes = c.state_nodes
    state_size = c.state_size
    state_elements = (state_nodes - 1) * state_size

    prob = ODEProblem(dynamics_example, SA[u[1:state_size]...], (times[1], times[end]))

    j = 1

    for i in 1:state_size:(state_elements)
        state = SA[u[i:(i + state_size - 1)]...]

        sol = solve(
            remake(prob; u0=state, tspan=(times[j], times[j + 1])),
            integrator;
            abstol=1e-8,
            reltol=1e-8,
            save_everystep=false,
            tstops=times,
            d_discontinuities=times,
        )
        y[i:(i + state_size - 1)] = sol.u[end]

        j += 1
    end

    return nothing
end

function compute_dynamics_arcs_step_reinit!(
    y::AbstractVector{T}, u::AbstractVector{T}, c
) where {T<:Real}
    times = c.times
    integrator = c.integrator
    state_nodes = c.state_nodes
    state_size = c.state_size
    state_elements = (state_nodes - 1) * state_size

    prob = ODEProblem(dynamics_example, SA[u[1:state_size]...], (times[1], times[end]))
    integ = init(prob, integrator; abstol=1e-8, reltol=1e-8, save_everystep=false, tstops = times, d_discontinuities = times)

    j = 1

    for i in 1:state_size:(state_elements)
        state = SA[u[i:(i + state_size - 1)]...]

        reinit!(integ, state, t0 = times[j], tf = times[j + 1])

        step!(integ, times[j + 1] - times[j], true)

        y[i:(i + state_size - 1)] = integ.u

        j += 1
    end

    return nothing
end

jac_pattern = spzeros(600, 600)

for i in 1:6:600
    jac_pattern[i:(i + 5), i:(i + 5)] .= true
end


sparsity_detector = KnownJacobianSparsityDetector(
    jac_pattern
)

backend = AutoSparse(
    AutoForwardDiff();
    sparsity_detector=sparsity_detector,
    coloring_algorithm=GreedyColoringAlgorithm(),
)

x = ones(600)
y = zeros(600)

context = Constant((
    integrator = Tsit5(),
    times = collect(LinRange(0.0, 1.0, 101)),
    state_nodes = 101,
    state_size = 6,
));

jac_prep = prepare_jacobian(
    compute_dynamics_arcs_solve!,
    y,
    backend,
    x,
    context
);

jac_buffer = similar(sparsity_pattern(jac_prep), eltype(x))


jacobian!(
    compute_dynamics_arcs_solve!,
    y,
    jac_buffer,
    jac_prep,
    backend,
    x,
    context,
)

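# A preparation object is only valid for the function it was built with,
# so prepare again for the reinit!-based variant.
jac_prep_reinit = prepare_jacobian(
    compute_dynamics_arcs_step_reinit!,
    y,
    backend,
    x,
    context
);

jacobian!(
    compute_dynamics_arcs_step_reinit!,
    y,
    jac_buffer,
    jac_prep_reinit,
    backend,
    x,
    context,
)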

The reinit! approach was 10-20% faster for this example.
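
As a side note on why the sparse backend pays off here: with the block-diagonal pattern from the example, columns belonging to different arcs never share a row, so they can be differentiated together. The sketch below checks this with a hand-written coloring (`mod1(j, 6)`), which is an assumption for illustration and not necessarily what GreedyColoringAlgorithm produces; `is_valid_column_coloring` is a hypothetical helper.

```julia
using SparseArrays

n_state, n_arcs = 6, 100
n = n_state * n_arcs

# Same block-diagonal pattern as in the example.
pattern = spzeros(Bool, n, n)
for i in 1:n_state:n
    pattern[i:(i + n_state - 1), i:(i + n_state - 1)] .= true
end

# Hand-written column coloring: column j gets color mod1(j, n_state).
colors = [mod1(j, n_state) for j in 1:n]

# A column coloring is usable for forward mode if no two columns with the
# same color have a stored entry in the same row.
function is_valid_column_coloring(pattern, colors)
    rows = rowvals(pattern)
    for c in unique(colors)
        seen = falses(size(pattern, 1))
        for j in findall(==(c), colors), k in nzrange(pattern, j)
            seen[rows[k]] && return false
            seen[rows[k]] = true
        end
    end
    return true
end

coloring_ok = is_valid_column_coloring(pattern, colors)  # true
n_colors = length(unique(colors))                        # 6
```

With 6 colors, a forward-mode backend needs on the order of 6 directional derivatives rather than 600, assuming the coloring used in practice is comparably small.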