Storing only rolling mean and variance when solving SDEProblems

Hi all,
I have been trying to reduce memory overhead: in my systems of interest there are a lot of unnecessary reads/writes because the state is saved at every (specified) interval dt. For my purposes I only need the mean and the variance of some stochastic trajectories, so I wrote an online algorithm for them using a SavingCallback. However, I had mistakenly assumed that a SavingCallback could be instructed to keep only the latest value, whereas it in fact stores the computed values at every one of the given saveat points. Since I do not need anything other than the final mean and variance, saving any intermediate state (either the state u or intermediate values of the mean and variance) is useless for me and should be avoided in order to reduce the memory overhead of the solver. I have not managed to implement this yet.

My idea was to use a FunctionCallingCallback instead of a SavingCallback, but I would like the called function to work in place, so that it can overwrite the previous values of the mean/variance. How can I do this? Is there a way to define a function f!(mean, variance) that updates the mean and the variance based on the current state (u, t, integrator)?
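
To make the idea concrete, here is roughly the pattern I have in mind (an untested sketch; the accumulator acc and the naive unweighted mean update are placeholders for illustration, not the actual time-weighted update I would use):

using DifferentialEquations

acc = zeros(2)   #~ acc[1]: running mean, acc[2]: number of samples (hypothetical accumulator)
function update!(u, t, integrator)
    #~ Mutate the preallocated accumulator in place; nothing is saved by the solver
    acc[2] += 1
    acc[1] += (u - acc[1]) / acc[2]   #~ naive running mean, just to show the shape of the update
    return nothing
end

prob = SDEProblem((u, p, t) -> 1.0 - u, (u, p, t) -> 0.1, 0.0, (0.0, 100.0))
cb = FunctionCallingCallback(update!; funcat=1.0:1.0:100.0, func_start=false)
sol = solve(prob, LambaEM(); callback=cb, save_everystep=false, save_start=false)
acc[1]   #~ running mean over the sampled points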

Example: the Ornstein-Uhlenbeck process

For reference, here is a module OrnsteinUhlenbeck that models the Ornstein-Uhlenbeck process

dx_t = (\mu - x_t)dt + \sigma dW_t

for which the mean and variance are given by \textrm{mean}(x) = \mu and \textrm{Var}(x) = \sigma^2/2 (for large t). The code implements functions that create the SDEProblem and an EnsembleProblem which, in this case, simply simulates multiple instances of the process; I use an EnsembleProblem because my original problem also varies other parameters of the SDEProblem across trajectories.
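
(For reference: the general process dx_t = \theta(\mu - x_t)\,dt + \sigma\,dW_t has stationary variance \sigma^2/(2\theta); here \theta = 1, so with \sigma = 0.1 we expect \textrm{Var}(x) = 0.1^2/2 = 0.005.)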

`OrnsteinUhlenbeck` module
#= Module for simulating stochastic dynamics (an Ornstein-Uhlenbeck process driven by
   Brownian motion), saving/computing only the rolling mean and variance of the state
   in order to save memory
=#
#/ Start module
module OrnsteinUhlenbeck

#/ Packages
using DifferentialEquations
using OhMyThreads: TaskLocalValue
using Statistics
using Random

#################
### FUNCTIONS ###
"Define Ornstein-Uhlenbeck process with some mean and variance"
function generate_sdeproblem(; μ=1.0, σ=0.1, tspan=(0.0, 1e4), x₀=0.0)
    f(u,p,t) = μ - u
    g(u,p,t) = σ
    sprob = SDEProblem(f, g, x₀, tspan)
    return sprob
end

"Generate EnsembleProblem"
function generate_ensembleproblem(
    sprob::SDEProblem;
    ntrajectories=2,
    tstart = (sprob.tspan[end] - sprob.tspan[begin]) / 2,  #~ start measuring at tstart
    dt = 1.0
)
    #/ Define TaskLocalValue that deepcopy's the problem to the task if it does not exist
    #~ see: https://juliafolds2.github.io/OhMyThreads.jl/stable/literate/tls/tls/#TLV
    tlv_prob = TaskLocalValue{SDEProblem}(() -> deepcopy(sprob))
    
    T = Tuple{Float64, Float64, Float64}
    saved_values = [SavedValues(Float64, T) for _ in 1:ntrajectories]

    #/ Write saving callback that computes the rolling mean and variance
    function savingcallback(saved_value)
        τ = tstart
        xmean = 0.0          #~ mean ⟨u⟩
        sumofsquares = 0.0   #~ sum of squared differences Σ(u - ⟨u⟩)²
        #~ Define Welford algorithm for rolling mean and variance
        #!note: need to be properly weighted by the current time
        function f(u, t, integrator)
            #~ Compute timestep, update current time τ
            Δt = t - τ
            τ = τ + Δt
            #~ Compute differences and update mean and sum of squared differences
            du = u - xmean  #~ u[t] - ⟨u⟩[t-1]
            xmean = xmean + du * Δt / (τ - tstart)
            dv = u - xmean  #~ u[t] - ⟨u⟩[t]
            sumofsquares = sumofsquares + du * dv * Δt
            return xmean, sumofsquares, u
        end
        saveat = (tstart+dt):dt:sprob.tspan[end]
        return SavingCallback(f, saved_value, saveat=saveat)
    end

    #/ Define prob_func that mutates the (local) SDEProblem
    #~ here it only mutates the callback
    function prob_func(prob, i, nrepeats)
        #/ Get local SDEProblem that can be mutated safely
        localprob = tlv_prob[]
        localcallback = savingcallback(saved_values[i])
        localprob = remake(localprob, callback=localcallback)
        return localprob
    end

    #/ Define output_func that computes the true variance, and returns the mean and variance
    function output_func(sol, i)
        __mean = getindex.(saved_values[i].saveval, 1)[end]
        __var = getindex.(saved_values[i].saveval, 2)[end]
        total_time = sprob.tspan[end] - tstart
        __var = __var / total_time
        output = (__mean, __var)
        return (output, false)
    end

    #/ Create EnsembleProblem
    eprob = EnsembleProblem(
        sprob, prob_func=prob_func, output_func=output_func, safetycopy=false
    )
    return eprob, saved_values
end

########################
### HELPER FUNCTIONS ###
"Solve an EnsembleProblem"
function solve_ensembleproblem(eprob::EnsembleProblem; ntrajectories=2)
    esol = solve(
        eprob, LambaEM(), EnsembleThreads(), trajectories=ntrajectories, save_start=false
    )
    return esol
end

end # module OrnsteinUhlenbeck
#/ End module

Running the code above, I indeed find for \mu=1 and \sigma=0.1:

julia> sprob = OrnsteinUhlenbeck.generate_sdeproblem();

julia> eprob, saved_values = OrnsteinUhlenbeck.generate_ensembleproblem(sprob);

julia> esol = OrnsteinUhlenbeck.solve_ensembleproblem(eprob);

julia> esol.u
2-element Vector{Tuple{Float64, Float64}}:
 (0.998592191232546, 0.005270329850583074)
 (1.00059816575466, 0.0055156836342161245)

so the mean and variance have the values that we expect in these 2 trajectories, but saved_values contains a bunch of data that I do not need (see below).

At its core, the module implements Welford’s online algorithm for time series, which, when using a SavingCallback, is defined as:

#/ Write saving callback that computes the rolling mean and variance
    function savingcallback(saved_value)
        τ = tstart
        xmean = 0.0          #~ mean ⟨u⟩
        sumofsquares = 0.0   #~ sum of squared differences Σ(u - ⟨u⟩)²
        #~ Define Welford algorithm for rolling mean and variance
        #!note: need to be properly weighted by the current time
        function f(u, t, integrator)
            #~ Compute timestep, update current time τ
            Δt = t - τ
            τ = τ + Δt
            #~ Compute differences and update mean and sum of squared differences
            du = u - xmean  #~ u[t] - ⟨u⟩[t-1]
            xmean = xmean + du * Δt / (τ - tstart)
            dv = u - xmean  #~ u[t] - ⟨u⟩[t]
            sumofsquares = sumofsquares + du * dv * Δt
            return xmean, sumofsquares, u
        end
        saveat = (tstart+dt):dt:sprob.tspan[end]
        return SavingCallback(f, saved_value, saveat=saveat)
    end

Obviously, this stores the rolling mean and variance (and the current state) at every point in saveat, and it is precisely this allocation that I would like to avoid: as mentioned above, saved_values ends up holding a lot of intermediate data that I never use.
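
For reference, writing M for sumofsquares, the time-weighted update that f performs at the n-th save point is

\bar{x}_n = \bar{x}_{n-1} + (x_n - \bar{x}_{n-1})\,\frac{\Delta t_n}{t_n - t_{\textrm{start}}}, \qquad M_n = M_{n-1} + (x_n - \bar{x}_{n-1})(x_n - \bar{x}_n)\,\Delta t_n,

and output_func finally returns \textrm{Var} \approx M_N / (t_{\textrm{end}} - t_{\textrm{start}}).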

Question

Is there a way to write a callback that updates the moments in place, or a way to modify the SavingCallback so that it works in place? Note that the output of output_func is the only output I actually want; in this case, esol.u already gives exactly the desired result.

Any help is greatly appreciated! Also feel free to ask for any clarifications.

Not easily? But it wouldn’t be hard to cook up such a callback using SavingCallback’s code. It could be a nice one to add to DiffEqCallbacks.jl, but yeah I don’t think there’s a pre-cooked callback that does this.

You can do this by just writing your own DiscreteCallback and it’s not much more work than using SavingCallback:

using StochasticDiffEq
using OhMyThreads: TaskLocalValue
using Statistics
using Random

#################
### FUNCTIONS ###
"Define Ornstein-Uhlenbeck process with some mean and variance"
function generate_sdeproblem(; μ=1.0, σ=0.1, tspan=(0.0, 1e4), x₀=0.0)
    f(u,p,t) = μ - u
    g(u,p,t) = σ
    sprob = SDEProblem(f, g, x₀, tspan)
    return sprob
end

"Generate EnsembleProblem"
function generate_ensembleproblem(
    sprob::SDEProblem;
    ntrajectories=2,
    tstart = (sprob.tspan[end] - sprob.tspan[begin]) / 2,  #~ start measuring at tstart
    dt = 1.0
)
    #/ Define TaskLocalValue that deepcopy's the problem to the task if it does not exist
    #~ see: https://juliafolds2.github.io/OhMyThreads.jl/stable/literate/tls/tls/#TLV
    mk_prob = () -> deepcopy(sprob)
    tlv_prob = TaskLocalValue{SDEProblem}(mk_prob)
    
    T = @NamedTuple{xmean::Float64, sumofsquares::Float64, u::Float64}

    saveat = (tstart+dt):dt:sprob.tspan[end]
    
    #/ Write saving callback that computes the rolling mean and variance

    accumulators = map(1:ntrajectories) do i
        Ref{T}()
    end
    
    function savingcallback(accumulator)
        τ = tstart
        xmean = 0.0          #~ mean ⟨u⟩
        sumofsquares = 0.0   #~ sum of squared differences Σ(u - ⟨u⟩)²
        #~ Define Welford algorithm for rolling mean and variance
        #!note: need to be properly weighted by the current time
        function f(u, t, integrator)
            #~ Compute timestep, update current time τ
            Δt = t - τ
            τ = τ + Δt
            #~ Compute differences and update mean and sum of squared differences
            du = u - xmean  #~ u[t] - ⟨u⟩[t-1]
            xmean = xmean + du * Δt / (τ - tstart)
            dv = u - xmean  #~ u[t] - ⟨u⟩[t]
            sumofsquares = sumofsquares + du * dv * Δt
            accumulator[] = (; xmean, sumofsquares, u)
        end
        saveat_set = Set(saveat)
        condition(u, t, integrator) = t ∈ saveat_set
        affect!(integrator) = f(integrator.u[end], integrator.t, integrator)
        DiscreteCallback(condition, affect!)
    end

    #/ Define prob_func that mutates the (local) SDEProblem
    #~ here it only mutates the callback
    function prob_func(prob, i, nrepeats)
        #/ Get local SDEProblem that can be mutated safely
        localprob = tlv_prob[]
        localcallback = savingcallback(accumulators[i])
        localprob = remake(localprob, callback=localcallback, tstops=saveat)
        return localprob
    end

    #/ Define output_func that computes the true variance, and returns the mean and variance
    function output_func(sol, i)
        total_time = sprob.tspan[end] - tstart
        
        __mean = accumulators[i][].xmean
        __var = accumulators[i][].sumofsquares / total_time
        
        output = (__mean, __var)
        return (output, false)
    end

    #/ Create EnsembleProblem
    eprob = EnsembleProblem(
        sprob, prob_func=prob_func, output_func=output_func, safetycopy=false
    )
    return eprob, accumulators
end

########################
### HELPER FUNCTIONS ###
"Solve an EnsembleProblem"
function solve_ensembleproblem(eprob::EnsembleProblem; ntrajectories=2)
    esol = solve(
        eprob, LambaEM(), EnsembleThreads(), trajectories=ntrajectories, save_start=false
    )
    return esol
end

and then

julia> begin
       sprob = generate_sdeproblem();
       eprob, saved_values = generate_ensembleproblem(sprob);
       esol = solve_ensembleproblem(eprob);
       esol.u
       end
2-element Vector{Tuple{Float64, Float64}}:
 (0.9993074606759617, 0.005608397563501149)
 (1.0001460814306193, 0.00533584458693052)

this gives the same statistics, but without storing the whole execution history: it just overwrites one Ref per trajectory at each save point.
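
In case Ref is unfamiliar: Ref{T}() is just an uninitialized, zero-dimensional mutable container, so each trajectory can keep overwriting its own accumulator without allocating any history. A tiny illustration:

T = @NamedTuple{xmean::Float64, sumofsquares::Float64, u::Float64}
r = Ref{T}()                                        # uninitialized container for one trajectory
r[] = (xmean = 1.0, sumofsquares = 0.5, u = 0.9)    # overwrite in place, no history kept
r[].xmean                                           # 1.0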


Here’s a diff of the relevant changes:

-using DifferentialEquations
+using StochasticDiffEq
 using OhMyThreads: TaskLocalValue
 using Statistics
 using Random
@@ -24,9 +24,11 @@ function generate_ensembleproblem(
     #~ see: https://juliafolds2.github.io/OhMyThreads.jl/stable/literate/tls/tls/#TLV
     tlv_prob = TaskLocalValue{SDEProblem}(() -> deepcopy(sprob))
     
-    T = Tuple{Float64, Float64, Float64}
-    saved_values = [SavedValues(Float64, T) for _ in 1:ntrajectories]
-
+    T = @NamedTuple{xmean::Float64, sumofsquares::Float64, u::Float64}
+    saved_values = [Ref{T}() for _ ∈ 1:ntrajectories]
+    
+    saveat = (tstart+dt):dt:sprob.tspan[end]
+    
     #/ Write saving callback that computes the rolling mean and variance
     function savingcallback(saved_value)
         τ = tstart
@@ -43,10 +45,12 @@ function generate_ensembleproblem(
             xmean = xmean + du * Δt / (τ - tstart)
             dv = u - xmean  #~ u[t] - ⟨u⟩[t]
             sumofsquares = sumofsquares + du * dv * Δt
-            return xmean, sumofsquares, u
+            saved_value[] = (; xmean, sumofsquares, u)
         end
-        saveat = (tstart+dt):dt:sprob.tspan[end]
-        return SavingCallback(f, saved_value, saveat=saveat)
+        saveat_set = Set(saveat)
+        condition(u, t, integrator) = t ∈ saveat_set
+        affect!(integrator) = f(integrator.u[end], integrator.t, integrator)
+        DiscreteCallback(condition, affect!)
     end
 
     #/ Define prob_func that mutates the (local) SDEProblem
@@ -55,16 +59,17 @@ function generate_ensembleproblem(
         #/ Get local SDEProblem that can be mutated safely
         localprob = tlv_prob[]
         localcallback = savingcallback(saved_values[i])
-        localprob = remake(localprob, callback=localcallback)
+        localprob = remake(localprob, callback=localcallback, tstops=saveat)
         return localprob
     end
 
     #/ Define output_func that computes the true variance, and returns the mean and variance
     function output_func(sol, i)
-        __mean = getindex.(saved_values[i].saveval, 1)[end]
-        __var = getindex.(saved_values[i].saveval, 2)[end]
         total_time = sprob.tspan[end] - tstart
-        __var = __var / total_time
+        
+        __mean = saved_values[i][].xmean
+        __var = saved_values[i][].sumofsquares / total_time
+        
         output = (__mean, __var)
         return (output, false)
     end

This isn’t quite right because you don’t want to force steps. You want to instead check and back interpolate.
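
A rough, untested sketch of that idea (it assumes the integrator interface's within-step interpolation integrator(t), and make_interp_callback is a hypothetical helper, not an existing API; as far as I understand, this is roughly what SavingCallback does internally):

function make_interp_callback(f, times)
    idx = Ref(1)   # index of the next save time we are waiting for
    function condition(u, t, integrator)
        idx[] <= length(times) && t >= times[idx[]]
    end
    function affect!(integrator)
        # handle every save time the last step jumped over, interpolating back to it
        while idx[] <= length(times) && times[idx[]] <= integrator.t
            tsave = times[idx[]]
            usave = integrator(tsave)   # interpolate within the last step instead of forcing a stop
            f(usave, tsave, integrator)
            idx[] += 1
        end
        u_modified!(integrator, false)
    end
    return DiscreteCallback(condition, affect!, save_positions=(false, false))
end

This way the solver keeps its adaptive steps and no tstops are needed.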


Thanks for the suggestions. I really need to read up on how Refs work, as I am not yet familiar with them. In the meantime, I have also managed to find an implementation that relies on a FunctionCallingCallback which simply updates values in a preallocated array. The relevant code is as follows:

function functioncallback(saved_value)
    #~ note: S, dt, tstart and prob are captured from the enclosing function
    du = zeros(Float64, S)
    #~ Define Welford's algorithm for rolling mean and variance
    #!note: need to be properly weighted by the current time
    function f(u, t, integrator)
        #~ Compute differences and update mean and sum of squared differences
        du .= u .- saved_value[:,1]
        saved_value[:,1] .= saved_value[:,1] .+ du .* (dt / (t - tstart))
        saved_value[:,2] .= saved_value[:,2] .+ du .* (u .- saved_value[:,1]) .* dt
        return nothing
    end
    #!note: start at tstart+dt to avoid Δt=0
    funcat = (tstart+dt):dt:prob.tspan[end]
    return FunctionCallingCallback(f, funcat=funcat)
end

This seems to allocate the same amount regardless of dt, which is exactly the behavior that I would expect. It also appears to be thread-safe, but I am not sure whether it actually is…

This looks fine
