Hi all,
I have been trying to reduce memory overhead. In my systems of interest there were a lot of unnecessary reads/writes, because the state was saved at each (specified) interval `dt`. For my purposes I only need the mean and the variance of some stochastic trajectories, so I wrote an online algorithm for them using a `SavingCallback`. However, I mistakenly assumed that the `SavingCallback` could be told to save only a specific value. In fact it stores the computed values at every time in `saveat`. I have been trying to figure out how to avoid this. I truly do not need anything other than the final mean and variance. Saving any intermediate state (either the state `u` or intermediate values of the mean and variance) is therefore useless for me and should be avoided to reduce the memory overhead of the solver. So far, I have not been able to implement this.
My idea was to use a `FunctionCallingCallback` instead of a `SavingCallback`. However, I would like this function to be in-place, so that it can directly overwrite the previous values of the means/variances. How can I do this? Is there a way to define a function `f!(mean, variance)` that updates the mean and the variance based on the current state `(u, t, integrator)`?
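To make the `f!(mean, variance)` idea concrete, here is a minimal sketch in plain Julia, independent of DifferentialEquations. It keeps the time-weighted Welford state in a small mutable struct and overwrites it on every call, so no history is stored. The names `WelfordAccumulator`, `welford_update!` and `weighted_variance` are hypothetical; the update itself is the same as in the `f` used further down. The commented-out wiring through `FunctionCallingCallback` (from DiffEqCallbacks, called as `func(u, t, integrator)` at the `funcat` times) is an assumption and is not executed here. In an ensemble, one would presumably need one accumulator per trajectory, just like `saved_values[i]`.

```julia
# Hypothetical accumulator (not part of any library): the whole "history" of the
# time-weighted Welford recursion is just these four numbers, overwritten in place.
mutable struct WelfordAccumulator
    tstart::Float64  # time at which measuring starts
    τ::Float64       # time of the previous update
    mean::Float64    # running time-weighted mean ⟨u⟩
    sumsq::Float64   # running weighted sum of squared differences
end

WelfordAccumulator(tstart::Real) = WelfordAccumulator(tstart, tstart, 0.0, 0.0)

# In-place update from the state `u` at time `t` (the f!(mean, variance) idea).
# The unused `integrator` argument only mirrors the (u, t, integrator) callback signature.
function welford_update!(acc::WelfordAccumulator, u, t, integrator=nothing)
    Δt = t - acc.τ
    acc.τ = t
    du = u - acc.mean                      # u[t] - ⟨u⟩[t-1]
    acc.mean += du * Δt / (t - acc.tstart)
    dv = u - acc.mean                      # u[t] - ⟨u⟩[t]
    acc.sumsq += du * dv * Δt
    return nothing
end

# Time-weighted variance: sum of squares divided by the elapsed measuring time
weighted_variance(acc::WelfordAccumulator) = acc.sumsq / (acc.τ - acc.tstart)

# Assumed wiring (not run here; needs DiffEqCallbacks and the module's tstart, dt, sprob):
#   acc = WelfordAccumulator(tstart)
#   cb  = FunctionCallingCallback((u, t, integrator) -> welford_update!(acc, u, t, integrator);
#                                 funcat = (tstart + dt):dt:sprob.tspan[end], func_start = false)

# Small check with equal time steps, where the result is the ordinary population mean/variance
acc = WelfordAccumulator(0.0)
for (t, u) in zip(1.0:4.0, [1.0, 2.0, 3.0, 4.0])
    welford_update!(acc, u, t)
end
println(acc.mean, " ", weighted_variance(acc))  # 2.5 1.25
```

The `func_start = false` in the comment is there because a call at the initial time `t = 0 < tstart` would produce a negative weight.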
Example: the Ornstein-Uhlenbeck process
For reference, here is a module `OrnsteinUhlenbeck` that models the Ornstein-Uhlenbeck process. For this process, the mean and variance are given by $\textrm{mean}(x) = \mu$ and $\textrm{Var}(x) = \sigma^2/2$ (for large $t$). The code implements functions that create the `SDEProblem` and an `EnsembleProblem` that, in this case, models multiple instances of the process. I set it up this way because my original problem also changes other parameters of the `SDEProblem`.
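In `generate_sdeproblem` below, the drift `f(u,p,t) = μ - u` and the constant noise `g(u,p,t) = σ` correspond to the following SDE, with the mean-reversion rate fixed to 1. Its stationary moments are the ones quoted above. For the defaults $\mu = 1$, $\sigma = 0.1$, one therefore expects a mean of $1$ and a variance of $0.1^2/2 = 0.005$:

$$
dx_t = (\mu - x_t)\,dt + \sigma\,dW_t, \qquad
\lim_{t\to\infty} \mathbb{E}[x_t] = \mu, \qquad
\lim_{t\to\infty} \mathrm{Var}(x_t) = \frac{\sigma^2}{2}.
$$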
`OrnsteinUhlenbeck` module
```julia
#= Module for simulating stochastic dynamics (Brownian motion) and saving/computing only the
rolling mean and variance of the state in order to save memory
=#
#/ Start module
module OrnsteinUhlenbeck
#/ Packages
using DifferentialEquations
using OhMyThreads: TaskLocalValue
using Statistics
using Random
#################
### FUNCTIONS ###
"Define Ornstein-Uhlenbeck process with some mean and variance"
function generate_sdeproblem(; μ=1.0, σ=0.1, tspan=(0.0, 1e4), x₀=0.0)
    f(u, p, t) = μ - u
    g(u, p, t) = σ
    sprob = SDEProblem(f, g, x₀, tspan)
    return sprob
end
"Generate EnsembleProblem"
function generate_ensembleproblem(
    sprob::SDEProblem;
    ntrajectories=2,
    tstart = (sprob.tspan[end] - sprob.tspan[begin]) / 2, #~ start measuring at tstart
    dt = 1.0
)
    #/ Define TaskLocalValue that deepcopies the problem to the task if it does not exist
    #~ see: https://juliafolds2.github.io/OhMyThreads.jl/stable/literate/tls/tls/#TLV
    tlv_prob = TaskLocalValue{SDEProblem}(() -> deepcopy(sprob))
    T = Tuple{Float64, Float64, Float64}
    saved_values = [SavedValues(Float64, T) for _ in 1:ntrajectories]
    #/ Write saving callback that computes the rolling mean and variance
    function savingcallback(saved_value)
        τ = tstart
        xmean = 0.0 #~ mean ⟨u⟩
        sumofsquares = 0.0 #~ sum of squared differences Σ(u - ⟨u⟩)²
        #~ Define Welford algorithm for rolling mean and variance
        #!note: need to be properly weighted by the current time
        function f(u, t, integrator)
            #~ Compute timestep, update current time τ
            Δt = t - τ
            τ = τ + Δt
            #~ Compute differences and update mean and sum of squared differences
            du = u - xmean #~ u[t] - ⟨u⟩[t-1]
            xmean = xmean + du * Δt / (τ - tstart)
            dv = u - xmean #~ u[t] - ⟨u⟩[t]
            sumofsquares = sumofsquares + du * dv * Δt
            return xmean, sumofsquares, u
        end
        saveat = (tstart+dt):dt:sprob.tspan[end]
        return SavingCallback(f, saved_value, saveat=saveat)
    end
    #/ Define prob_func that mutates the (local) SDEProblem
    #~ here it only mutates the callback
    function prob_func(prob, i, nrepeats)
        #/ Get local SDEProblem that can be mutated safely
        localprob = tlv_prob[]
        localcallback = savingcallback(saved_values[i])
        localprob = remake(localprob, callback=localcallback)
        return localprob
    end
    #/ Define output_func that computes the true variance, and returns the mean and variance
    function output_func(sol, i)
        __mean = saved_values[i].saveval[end][1] #~ last rolling mean
        __var = saved_values[i].saveval[end][2]  #~ last sum of squared differences
        total_time = sprob.tspan[end] - tstart
        __var = __var / total_time
        output = (__mean, __var)
        return (output, false)
    end
    #/ Create EnsembleProblem
    eprob = EnsembleProblem(
        sprob, prob_func=prob_func, output_func=output_func, safetycopy=false
    )
    return eprob, saved_values
end
########################
### HELPER FUNCTIONS ###
"Solve an EnsembleProblem"
function solve_ensembleproblem(eprob::EnsembleProblem; ntrajectories=2)
    esol = solve(
        eprob, LambaEM(), EnsembleThreads(), trajectories=ntrajectories, save_start=false
    )
    return esol
end
end # module OrnsteinUhlenbeck
#/ End module
```
Running the code above with $\mu=1$ and $\sigma=0.1$, I indeed find:
```julia
julia> sprob = OrnsteinUhlenbeck.generate_sdeproblem();
julia> eprob, saved_values = OrnsteinUhlenbeck.generate_ensembleproblem(sprob);
julia> esol = OrnsteinUhlenbeck.solve_ensembleproblem(eprob);
julia> esol.u
2-element Vector{Tuple{Float64, Float64}}:
 (0.998592191232546, 0.005270329850583074)
 (1.00059816575466, 0.0055156836342161245)
```
So in these 2 trajectories the mean and variance are close to the expected values. However, `saved_values` contains a lot of data that I do not need (see below).
At its core, the module implements Welford's online algorithm for time series, weighted by the time step. With a `SavingCallback`, it reads:
```julia
#/ Write saving callback that computes the rolling mean and variance
function savingcallback(saved_value)
    τ = tstart
    xmean = 0.0 #~ mean ⟨u⟩
    sumofsquares = 0.0 #~ sum of squared differences Σ(u - ⟨u⟩)²
    #~ Define Welford algorithm for rolling mean and variance
    #!note: need to be properly weighted by the current time
    function f(u, t, integrator)
        #~ Compute timestep, update current time τ
        Δt = t - τ
        τ = τ + Δt
        #~ Compute differences and update mean and sum of squared differences
        du = u - xmean #~ u[t] - ⟨u⟩[t-1]
        xmean = xmean + du * Δt / (τ - tstart)
        dv = u - xmean #~ u[t] - ⟨u⟩[t]
        sumofsquares = sumofsquares + du * dv * Δt
        return xmean, sumofsquares, u
    end
    saveat = (tstart+dt):dt:sprob.tspan[end]
    return SavingCallback(f, saved_value, saveat=saveat)
end
```
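Written out, the update in `f` is a time-weighted form of Welford's recurrence. Each sample $u_k$ at time $t_k$ gets weight $\Delta t_k$, and the total weight is the elapsed measuring time $W_k$. The recurrence starts from $t_0 = t_{\mathrm{start}}$, $\bar u_0 = 0$, $S_0 = 0$, and the last line is the division by `total_time` done in `output_func`:

$$
\begin{aligned}
\Delta t_k &= t_k - t_{k-1}, \qquad W_k = t_k - t_{\mathrm{start}},\\
\bar{u}_k &= \bar{u}_{k-1} + \frac{\Delta t_k}{W_k}\,(u_k - \bar{u}_{k-1}),\\
S_k &= S_{k-1} + \Delta t_k\,(u_k - \bar{u}_{k-1})(u_k - \bar{u}_k),\\
\mathrm{Var} &\approx \frac{S_n}{W_n}.
\end{aligned}
$$

Only the latest $\bar u_k$, $S_k$ and $t_k$ enter each step. Nothing from earlier samples has to be kept.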
This stores the rolling mean and sum of squares (and the current state) at every time in `saveat`. It is precisely this allocation that I would like to avoid: `saved_values` ends up holding a lot of data that I do not need, as mentioned above.
Question
Is there a way to write a callback that updates the moments in-place? Alternatively, is there a way to modify the `SavingCallback` so that it works in-place? Note that the output of `output_func` is the output I actually want. In this case `esol.u` already gives the desired result.
Any help is greatly appreciated! Also feel free to ask for any clarifications.