Speed of EnsembleProblem

Here I’ve written a simple MRE that generates 100 simulations of a linear ODE with 100 different controls (forcings). I find that simply writing a for-loop to generate these 100 simulations is 10x faster than generating them as an EnsembleProblem. Is there a reason for this, or is my code not optimized?

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using DifferentialEquations
using Interpolations  # provides linear_interpolation
using Random


# Simulate Data
function true_ode(du, u, control, t)
    x, y = u
    du[1] = -y
    du[2] = x
    du .+= control(t)
end

tspan = (0.0, 10.0)
Nt = 100
saveat = LinRange(tspan[1], tspan[2], Nt)

u0 = zeros(2)

# number of controlled trajectories
Nc = 100
rng = MersenneTwister(1111)
disc_controls = randn(rng, Nc, Nt)
controls = [linear_interpolation(saveat, disc_controls[i, :]) for i = 1:Nc]

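# Baseline: solve each controlled trajectory in a plain loop
@time begin
    sols = zeros(2, Nt, Nc)
    prob = ODEProblem(true_ode, u0, tspan, controls[1])
    for i = 1:Nc
        # use a loop-local name so this also runs as a script (avoids the soft-scope UndefVarError on `prob`)
        prob_i = remake(prob, p=controls[i])
        sols[:, :, i] = Array(solve(prob_i, Tsit5(); saveat=saveat))
    end
end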

# Each trajectory swaps in the i-th control as the ODE parameter
function prob_func(prob, i, repeat)
    remake(prob, p=controls[i])
end
ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
@time begin
    sols2 = Array(solve(ensemble_prob, Tsit5(), trajectories=Nc; saveat=saveat))
    nothing
end

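# Skip the default deepcopy of prob before each prob_func call; safe here because prob_func only calls remake
ensemble_prob = EnsembleProblem(prob, prob_func=prob_func, safetycopy=false)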
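For context: when a custom `prob_func` is supplied, `EnsembleProblem` defaults to `safetycopy=true`, which deep-copies `prob` before each `prob_func` call to protect against prob_funcs that mutate shared state. Here `prob_func` only calls `remake`, so the copy is unnecessary. Also note that `@time` on a first call includes compilation. A sketch that rebuilds the MRE's setup and times each variant after a warm-up solve; the helper `timed_ensemble` is hypothetical, and timings are machine-dependent, so no numbers are claimed:

```julia
using DifferentialEquations
using Interpolations
using Random

# Same model and controls as the MRE above
function true_ode(du, u, control, t)
    x, y = u
    du[1] = -y
    du[2] = x
    du .+= control(t)
end

tspan = (0.0, 10.0)
Nt, Nc = 100, 100
saveat = LinRange(tspan[1], tspan[2], Nt)
rng = MersenneTwister(1111)
disc_controls = randn(rng, Nc, Nt)
controls = [linear_interpolation(saveat, disc_controls[i, :]) for i = 1:Nc]
prob = ODEProblem(true_ode, zeros(2), tspan, controls[1])
prob_func(prob, i, repeat) = remake(prob, p=controls[i])

# Hypothetical helper: solve once to compile, then time a second solve
function timed_ensemble(safetycopy::Bool)
    ens = EnsembleProblem(prob, prob_func=prob_func, safetycopy=safetycopy)
    solve(ens, Tsit5(); trajectories=Nc, saveat=saveat)  # warm-up run (includes compilation)
    return @elapsed solve(ens, Tsit5(); trajectories=Nc, saveat=saveat)
end

t_copy = timed_ensemble(true)
t_nocopy = timed_ensemble(false)
println("safetycopy=true: $(t_copy) s, safetycopy=false: $(t_nocopy) s")
```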


Great! Now they both take approximately 11ms.

@ChrisRackauckas I had a quick follow-up question. What is a good way to combine ODE parameters with a control input function?

One idea I had was to pass both the parameters and the control input through the ODE parameter argument as a tuple, p = (params, control), where “control” is a function of time and “params” is a vector. But this causes an issue if I then want to optimize params, because a tuple parameter cannot be optimized with sensitivity methods.
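A minimal sketch of the tuple approach, assuming a hypothetical two-parameter version of `true_ode` (`controlled_ode!`, with `params = [a, b]` scaling the rotation terms) and an arbitrary `sin` forcing in place of the random controls:

```julia
using DifferentialEquations
using Interpolations

# Hypothetical parameterized variant of true_ode: p = (params, control)
function controlled_ode!(du, u, p, t)
    params, control = p      # unpack the tuple
    a, b = params
    x, y = u
    du[1] = -a * y
    du[2] = b * x
    du .+= control(t)        # same additive forcing as true_ode
    return nothing
end

tspan = (0.0, 10.0)
saveat = LinRange(tspan[1], tspan[2], 100)
control = linear_interpolation(saveat, sin.(saveat))  # example forcing
params = [1.0, 1.0]

prob = ODEProblem(controlled_ode!, zeros(2), tspan, (params, control))
sol = solve(prob, Tsit5(); saveat=saveat)
```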

Another option is to pass p = [params; control] as a vector, where “control” is now the continuous control function sampled at discrete time points. But this causes another problem: I now need to interpolate the discrete control values inside my ODE function.
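A minimal sketch of the vector approach, assuming the same hypothetical `[a, b]` parameters and the uniform `saveat` grid from the MRE. The helper `interp_uniform` is hypothetical: it does linear interpolation with plain arithmetic on the entries of `p`, so `p` stays a flat numeric vector:

```julia
using DifferentialEquations

tspan = (0.0, 10.0)
saveat = LinRange(tspan[1], tspan[2], 100)

# Hypothetical helper: linear interpolation of `vals` on a uniform grid starting at t0 with spacing dt
function interp_uniform(vals, t, t0, dt)
    n = length(vals)
    s = clamp((t - t0) / dt, 0, n - 1)  # fractional (0-based) grid index, clamped to the grid
    i = min(floor(Int, s), n - 2)       # left knot, keeps i + 2 within bounds
    w = s - i                           # weight of the right knot
    return (1 - w) * vals[i + 1] + w * vals[i + 2]
end

# p = [a; b; control samples at saveat]
function vec_ode!(du, u, p, t)
    a, b = p[1], p[2]
    control_vals = @view p[3:end]
    c = interp_uniform(control_vals, t, saveat[1], step(saveat))
    x, y = u
    du[1] = -a * y + c
    du[2] = b * x + c
    return nothing
end

p = [1.0; 1.0; sin.(saveat)]
prob = ODEProblem(vec_ode!, zeros(2), tspan, p)
sol = solve(prob, Tsit5(); saveat=saveat)
```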

Are there any other options to resolve this problem?


Indeed, that’s a fine way to do it, but it’ll only work with forward-mode AD for now. The fix for this should be coming in about a month.
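As an illustration of the forward-mode route with the tuple form, a minimal sketch assuming ForwardDiff.jl is installed, the hypothetical `controlled_ode!` model with `params = [a, b]`, and a made-up loss (the squared norm of the saved trajectory). This differentiates through the solver itself rather than using adjoint sensitivity methods:

```julia
using DifferentialEquations
using Interpolations
using ForwardDiff

# Hypothetical parameterized variant of true_ode: p = (params, control)
function controlled_ode!(du, u, p, t)
    params, control = p
    a, b = params
    x, y = u
    du[1] = -a * y
    du[2] = b * x
    du .+= control(t)
    return nothing
end

tspan = (0.0, 10.0)
saveat = LinRange(tspan[1], tspan[2], 100)
control = linear_interpolation(saveat, sin.(saveat))
prob = ODEProblem(controlled_ode!, zeros(2), tspan, ([1.0, 1.0], control))

# Hypothetical loss: squared norm of the saved trajectory
function loss(params)
    u0 = zeros(eltype(params), 2)  # promote u0 to the Dual element type
    _prob = remake(prob; u0=u0, p=(params, control))
    sol = solve(_prob, Tsit5(); saveat=saveat)
    return sum(abs2, Array(sol))
end

grad = ForwardDiff.gradient(loss, [1.0, 1.0])
```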