Hi, I've been working with multiple shooting for trajectory optimization. I'm using the sparse backend from DifferentiationInterface.jl for the derivatives and DifferentialEquations.jl for the integration. A snippet of my current code is below:
```julia
using DifferentialEquations, StaticArrays
using Accessors  # provides @set (Setfield.jl exports an equivalent @set)
# `ode_dynamics`, `dynamics_callback` and `dynamics_tstops` are defined elsewhere

function compute_dynamics_arcs!(
    y::AbstractVector{T}, u::AbstractVector{T}, c
) where {T<:Real}
    dyn = c.dyn
    times = c.times
    integrator = c.integrator
    state_nodes = c.state_nodes
    state_size = c.state_size
    state_elements = (state_nodes - 1) * state_size
    column_start = state_elements + 1
    # For every component with a `times` field, rebuild its `control` vector
    # from the entries of `u` that follow the state block
    for i in 1:length(dyn.components)
        if :times in fieldnames(typeof(dyn.components[i]))
            control_size = length(dyn.components[i].control[1])
            control_nodes = length(dyn.components[i].control)
            control_elements = control_nodes * control_size
            @views controls = [
                SA[u[k:(k + control_size - 1)]...] for
                k in column_start:control_size:(column_start + control_elements - 1)
            ]
            dyn = @set dyn.components[i].control = controls
            column_start += control_elements
        end
    end
    # Using solve rather than step!/reinit! tends to be faster under AD, as
    # DifferentialEquations provides a lot of optimizations for the solve function
    prob = ODEProblem(ode_dynamics, SA[u[1:state_size]...], (times[1], times[end]), dyn)
    # Propagate each state node over its arc and store the end state in `y`
    j = 1
    for i in 1:state_size:state_elements
        @views state = SA[u[i:(i + state_size - 1)]...]
        tspan = (times[j], times[j + 1])
        cb = dynamics_callback(dyn, state, tspan)
        tstops = dynamics_tstops(dyn, tspan)
        sol = solve(
            remake(prob; u0=state, tspan=tspan), integrator;
            callback=cb, abstol=1e-8, reltol=1e-8,
            tstops=tstops, d_discontinuities=tstops, save_everystep=false,
        )
        y[i:(i + state_size - 1)] = sol.u[end]
        j += 1
    end
    return nothing
end
```
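The states and controls above are read from u by splatting a slice whose length is only known at run time into SA[...], so the compiler cannot infer the length of the resulting SVector. A minimal sketch of an alternative, assuming the chunk length can be supplied as a compile-time constant through Val (for example by storing it as a type parameter in c); static_chunk is a hypothetical helper, not part of StaticArrays:

```julia
using StaticArrays

# Hypothetical helper: copy the N consecutive entries of `u` starting at index `i`
# into an SVector. Because N arrives as Val(N), the result type SVector{N,T} is
# inferable, unlike SA[u[i:(i + n - 1)]...] where n is a run-time value.
@inline function static_chunk(u::AbstractVector{T}, i::Integer, ::Val{N}) where {T,N}
    return SVector{N,T}(ntuple(k -> u[i + k - 1], Val(N)))
end
```

If the length is only known at run time, Val(state_size) would need to be created once outside the hot loop and passed through a function barrier; inside that barrier, a comprehension such as [static_chunk(u, i, Val(N)) for i in 1:N:state_elements] yields concretely typed SVectors.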
```
BenchmarkTools.Trial: 1507 samples with 1 evaluation per sample.
 Range (min … max):  2.105 ms … 42.442 ms  ┊ GC (min … max):  0.00% … 93.07%
 Time  (median):     2.282 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   3.315 ms ±  2.860 ms  ┊ GC (mean ± σ):  21.10% ± 19.65%

  2.11 ms        Histogram: log(frequency) by time        12.7 ms <

 Memory estimate: 4.36 MiB, allocs estimate: 15871.
```
I'm having trouble reducing the allocations and the solver initialization cost to get a speedup. Profiling shows that roughly half of the run time is spent initializing the DifferentialEquations solver. When I tried the step!/reinit! approach, which seemed like the natural fix, performance was worse; my guess is that DifferentialEquations has optimized code paths for solve calls when it detects differentiation. Is there any way to keep the underlying caches between repeated solve calls?
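For reference, a minimal self-contained sketch of the init/reinit!/solve! reuse pattern, keeping one cached integrator per state element type (Float64 for plain evaluation, a dual type under ForwardDiff-based AD, since the integrator caches are typed on the state). It assumes a toy out-of-place ODE (decay) in place of ode_dynamics, Tsit5 in place of the configured integrator, and no per-arc callbacks; ArcIntegrators, get_integrator! and propagate_arcs are hypothetical names.

```julia
using OrdinaryDiffEq, StaticArrays  # DifferentialEquations.jl re-exports these names

# Hypothetical stand-in for `ode_dynamics`: out-of-place linear decay du/dt = -p*u
decay(u, p, t) = -p .* u

# Hypothetical container: one integrator per state element type, built lazily
# with `init` and reused afterwards
struct ArcIntegrators{P}
    p::P
    cache::Dict{DataType,Any}
end
ArcIntegrators(p) = ArcIntegrators(p, Dict{DataType,Any}())

function get_integrator!(ai::ArcIntegrators, u0::SVector{N,T}, tspan) where {N,T}
    get!(ai.cache, T) do
        prob = ODEProblem(decay, u0, tspan, ai.p)
        init(prob, Tsit5(); abstol=1e-8, reltol=1e-8, save_everystep=false)
    end
end

# Propagate each state over its arc [times[j], times[j+1]], reusing the cached
# integrator through reinit!/solve! instead of building a new one per arc
function propagate_arcs(ai::ArcIntegrators, states::Vector{SVector{N,T}}, times) where {N,T}
    out = similar(states)
    for j in eachindex(states)
        integ = get_integrator!(ai, states[j], (times[j], times[j + 1]))
        reinit!(integ, states[j]; t0=times[j], tf=times[j + 1])
        solve!(integ)
        out[j] = integ.u  # state at tf
    end
    return out
end
```

The Dict lookup returns an abstractly typed integrator, so each reinit!/solve! call dispatches dynamically. reinit! also accepts tstops and d_discontinuities keywords for per-arc stop times. Whether this beats remake + solve under AD would need to be benchmarked rather than assumed.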
Any thoughts on how this code may be improved? Thanks!