Reduce the batches of a parallel EnsembleProblem to the mean of the squared modulus

Hi,

I’m solving a generic EnsembleProblem. Let’s take a simple example:

# Scalar test problem u' = 1.01u with u(0) = 0.5 on t in [0, 1].
prob = let growth = (u, p, t) -> 1.01 * u
    ODEProblem(growth, 0.5, (0.0, 1.0))
end

# Build trajectory `i`: the base initial condition scaled by a fresh `rand()` draw.
function prob_func(prob, i, repeat)
    scale = rand()
    return remake(prob; u0 = scale * prob.u0)
end

# 100 trajectories, solved in batches of 20 with the distributed ensemble algorithm.
ensemble_prob = EnsembleProblem(prob; prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleDistributed();
            trajectories = 100, batch_size = 20)

and I want it to compute the timeseries_steps_mean of the squared modulus of the solutions after every batch of batch_size trajectories.

I can do it once the simulation is finished, with timestep_mean(abs2.(sim), 1:step), but how can I do this after every batch? I think the reduction function is the way to go, but I didn’t find a working method.

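In the reduction function, just add up the result of each batch divided by the batch length (and divide the accumulated total by the number of batches at the end, so that you get the overall mean rather than a sum of batch means).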

I tried something like

# Two-component version of the same linear growth problem, u(0) = [0.5, 0.5].
prob = let growth = (u, p, t) -> 1.01 * u
    ODEProblem(growth, [0.5, 0.5], (0.0, 1.0))
end

# Build trajectory `i`: every component of u0 is scaled by the same `rand()` draw.
function prob_func(prob, i, repeat)
    scale = rand()
    return remake(prob; u0 = scale .* prob.u0)
end

# NOTE: this attempt fails. `data` is the batch's Vector of ODESolution objects, so
# `abs2.(data)` broadcasts `abs2` over the solutions themselves, i.e. it calls
# `abs2(::ODESolution)`, for which no method exists (the MethodError quoted below).
# `abs2` has to be applied to the saved state values of each solution instead.
function reduction(u,data,I)
    (u+sum(abs2.(data)),false)
end

# Serial ensemble: 100 trajectories, `reduction` is called once per batch of 20.
ensemble_prob = EnsembleProblem(prob; prob_func = prob_func, reduction = reduction)
sim = solve(ensemble_prob, Tsit5(), EnsembleSerial();
            trajectories = 100, batch_size = 20)

But it doesn’t work. It gives me an error.

What’s the error?

MethodError: no method matching abs2(::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, 
Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, 
Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, 
SciMLBase.AutoSpecialize, var"#435#436", UniformScaling{Bool}, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{},
 NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, 
Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, 
OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#435#436", 
UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, 
Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, 
OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats})

data is a Vector{ODESolution}. You would need to broadcast that on each ODESolution. (u+sum(map(abs2,map(abs2,data))),false) is a straightforward way, but there are ways to make that nicer.

It gives me the error

MethodError: no method matching abs2(::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, 
Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, 
Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, 
SciMLBase.AutoSpecialize, var"#11#12", UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, 
Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, 
OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#11#12", 
UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, 
Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, 
OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats})
Closest candidates are:

I solved it instead with the help of output_func:

# Two-component linear growth problem u' = 1.01u, u(0) = [0.5, 0.5], t in [0, 1].
prob = let growth = (u, p, t) -> 1.01 * u
    ODEProblem(growth, [0.5, 0.5], (0.0, 1.0))
end

# Build trajectory `i`: every component of u0 is scaled by the same `rand()` draw.
function prob_func(prob, i, repeat)
    scale = rand()
    return remake(prob; u0 = scale .* prob.u0)
end

"""
    output_func(sol, i)

Return the element-wise squared modulus (`abs2`) of every saved state of
trajectory `i` as a matrix with one column per saved time point. The second
tuple element is `false`, which is the "rerun" flag of the EnsembleProblem
`output_func` contract.

Assumes each saved state `sol.u[k]` is a vector and all have the same length,
as in the two-component problem used here.
"""
function output_func(sol, i)
    # `reduce(hcat, vs)` has a specialized method for a vector of vectors and,
    # unlike `hcat(vs...)`, does not splat one argument per saved time step.
    # Stacking first and squaring once also avoids one temporary per column.
    squared = abs2.(reduce(hcat, sol.u))
    return squared, false
end

"""
    reduction(u, batch, I)

Add up the `abs2` matrices produced by `output_func` for one batch (element-wise)
and append that batch sum as a new slice along dimension 3 of the running result
`u`. Returns `(result, false)`; the `false` tells the ensemble not to stop early.
"""
function reduction(u, batch, I)
    # Element-wise sum over the batch. Same values as
    # `sum(cat(batch...; dims = 3); dims = 3)`, but without splatting
    # `batch_size` arguments into `cat`.
    batch_sum = sum(batch)
    # Give the sum a singleton 3rd dimension (vectors become n×1×1) so that
    # successive batches stack along dims = 3, exactly like the `cat` version.
    batch_slice = reshape(batch_sum, size(batch_sum, 1), size(batch_sum, 2), 1)
    # First batch: `u` is still the (empty) initial value.
    isempty(u) && return batch_slice, false
    return cat(u, batch_slice; dims = 3), false
end

# Number of trajectories. It is used both by `solve` and to normalize the final sum,
# so changing one cannot silently make the other wrong.
ntrajectories = 100
ensemble_prob = EnsembleProblem(prob; prob_func = prob_func,
                                output_func = output_func, reduction = reduction)
# `saveat` makes every trajectory save at the same time points. With adaptive steps
# alone, trajectories started from different u0 may save different numbers of steps.
# That would break `cat(...; dims = 3)` in `reduction` and average values taken at
# different times.
sim = solve(ensemble_prob, Tsit5(), EnsembleSerial();
            trajectories = ntrajectories, batch_size = 20, saveat = 0.1);
# Mean of abs2 over all trajectories at each saved time (per-batch sums are stacked
# along dimension 3).
solution = sum(sim.u; dims = 3) ./ ntrajectories
1 Like

Your last code there looks fine and runs fine?

Yes it works. Anyway, thank you for your help!