I’m using DifferentialEquations.jl to solve an `SDEProblem` with a custom `output_func`. I am then trying to summarize multiple trajectories with `EnsembleSummary`. However, the solution array in the resulting summary has the wrong dimensions (it seems to be collapsing everything?), and I can’t figure out why or how I should change it.
Here is a toy example of what I am doing:
```julia
using DifferentialEquations

# Define a simple SDE model
function f(du, u, p, t)
    du[1] = 1.01 * u[1]
    du[2] = 0.99 * u[2]
    du[3] = u[2] * u[1] - u[3]
end

# Some noise (diagonal: component k gets 0.1 * u[k] dW_k)
function noise!(du, u, p, t)
    du .= 0.1 * u
end

u0 = [0.5, 1.0, 1.2]
prob = SDEProblem(f, noise!, u0, (0.0, 1.0))

# Custom output function that should return 4 values per saved time
function output_func(sol, i)
    # Some 1-dimensional outputs
    output1 = 3 * sol[1, :]
    output2 = sol[1, :] + 2 * sol[2, :]
    # My problem has several of these, where I use multiple components of
    # the solution at the same time (this one is a 2-row matrix)
    output3 = 0.5 * sol[1:2, :] - 0.4 * sol[2:3, :]
    (u = [output1, output2, output3], t = sol.t), false
end

# Running the ensemble works
ensemble_prob = EnsembleProblem(prob; output_func = output_func)
sim = solve(ensemble_prob, SOSRI(), EnsembleThreads(); trajectories = 10, saveat = 0.1)

# However, summ.u has size (11, 2)! (It should be (11, 4).)
summ = EnsembleSummary(sim)
```
This is the problem:
```
julia> size(summ.u)
(11, 2)
```
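A guess at where the 2 comes from, based only on the shape of `summ.u` and not checked against the SciMLBase source: the object my `output_func` returns is a `NamedTuple`, and its `length` is its number of fields (2), not the number of saved times (11) or the number of outputs (3). If `EnsembleSummary` counts "time steps" with `length` and reads the matching entries of `.u`, it would see exactly two steps, `output1` and `output2`, and never reach `output3`. The snippet below uses random stand-in data (an assumption, not real solution values) just to show those lengths:

```julia
# Stand-in for what the output_func above returns for one trajectory
# (random numbers instead of real solution values): three outputs, the
# last one a 2-row matrix, on 11 saved times.
t = collect(0.0:0.1:1.0)
out = (u = [rand(11), rand(11), rand(2, 11)], t = t)

length(out)    # 2: a NamedTuple's length is its number of fields
length(out.u)  # 3: one entry per output, not one per saved time
length(out.t)  # 11: the number of saved times
```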
My question is: how should I structure the `output_func` so that `EnsembleSummary` combines the trajectories correctly?
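One restructuring I am considering, sketched under assumptions: build the outputs time-major, so that `u[j]` holds all 4 outputs at `sol.t[j]` (the same layout `sol.u` itself uses), and wrap them in a `DiffEqArray` from RecursiveArrayTools.jl. This assumes RecursiveArrayTools is added to the environment so `using` it works; `outputs_at` is a hypothetical helper name, and the arithmetic is the same as `output1`, `output2` and `output3` above, evaluated one saved time at a time.

```julia
using RecursiveArrayTools: DiffEqArray

# Hypothetical helper: all 4 outputs at one saved time, given the state x there.
outputs_at(x) = vcat(3 * x[1],                        # output1
                     x[1] + 2 * x[2],                 # output2
                     0.5 .* x[1:2] .- 0.4 .* x[2:3])  # output3 (2 values)

# Time-major output: u[j] holds every output at time sol.t[j], which is the
# layout an ODE/SDE solution itself uses (sol.u[j] is the state at sol.t[j]).
function output_func(sol, i)
    u = [outputs_at(x) for x in sol.u]
    return DiffEqArray(u, sol.t), false  # false: do not rerun the trajectory
end
```

It plugs into `EnsembleProblem(prob; output_func = output_func)` exactly as before. With this layout each trajectory looks like a time series of 4-component states on 11 saved times, so I would expect `EnsembleSummary(sim)` to report 4 components at 11 times, but I have not confirmed that with DifferentialEquations v7.16.1 and the current RecursiveArrayTools, so it is something to try rather than a known fix.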
Additional info
This is my `summ.u`, which for some reason ignores `output3` (even if I change `output3` to be a vector instead of a matrix):
```
julia> summ.u
t: 11-element Vector{Float64}:
 0.0
 0.1
 0.2
 0.3
 0.4
 0.5
 0.6
 0.7
 0.8
 0.9
 1.0
u: 2-element Vector{Vector{Float64}}:
 [1.5, 1.6779729999930952, 1.882802689892928, 2.0791973541348807, 2.289813760785018, 2.522292007510237, 2.7381454998246877, 3.0130258988579692, 3.3164391128403614, 3.6792412500574008, 4.0170165276054375]
 [2.5, 2.7426439877742514, 3.037732776506048, 3.2959205715740034, 3.63409275722072, 3.979916259016065, 4.370649184574971, 4.792145677419358, 5.3130449045579144, 5.949865231659005, 6.497619744969697]
```
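The two vectors above do look like the trajectory means of `output1` and `output2`: every trajectory starts at `u0 = [0.5, 1.0, 1.2]`, so at t = 0 `output1 = 3 * 0.5 = 1.5` and `output2 = 0.5 + 2 * 1.0 = 2.5`, matching their first entries. As a fallback that does not depend on how `EnsembleSummary` interprets the output, the statistics could be computed directly from `sim.u` with the Statistics standard library. This is a sketch; `summarize` is a hypothetical helper, and it assumes each trajectory's output is a matrix with one row per output and one column per saved time.

```julia
using Statistics

# Hypothetical fallback that does not go through EnsembleSummary. Each element
# of `trajs` is one trajectory's output: a matrix with one row per output and
# one column per saved time.
function summarize(trajs; quantiles = (0.05, 0.95))
    stacked = cat(trajs...; dims = 3)  # outputs × times × trajectories
    # Apply f across trajectories at every (output, time) pair.
    across(f) = dropdims(mapslices(f, stacked; dims = 3); dims = 3)
    return (mean = across(mean),
            var = across(var),
            qlow = across(x -> quantile(x, quantiles[1])),
            qhigh = across(x -> quantile(x, quantiles[2])))
end
```

To use it, the original `output_func` could return `vcat(output1', output2', output3), false` (a 4×11 matrix), and then `summarize(sim.u)` gives 4×11 arrays. The first column of the mean should then be approximately `[1.5, 2.5, -0.15, 0.02]`, with `output3` included.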
And here is the installed DifferentialEquations.jl version:
```
julia> pkgversion(DifferentialEquations)
v"7.16.1"
```