Solving EnsembleProblem efficiently for large systems: memory issues

I apologize in advance for the length of this thread, but the problem setting is relatively complicated and this was the smallest MWE I could come up with that still remains clear.

Problem setting

I am encountering memory usage so large that my Julia processes get killed, yet I do not understand why so much memory is being allocated.

Essentially, I am trying to solve (generalized) Lotka-Volterra systems. We can consider the very simple form:

\frac{dx_i}{dt} = x_i (1 - x_i - \sum_{j\neq i} A_{ij} x_j)

I am using ModelingToolkit.jl in conjunction with DifferentialEquations.jl to create an EnsembleProblem and solve many instances with random A_{ij} in parallel. As systems of this kind are known to be chaotic, let us assume for now that, for the given parameters, there is a unique fixed point to which all systems converge. Note that this fixed point obviously depends on all the A_{ij}, but not on the initial conditions. I am interested in the state of the system at some (large) time t, so I only save the final state of the solution, which should also save a lot of memory.

The code I am using to investigate this problem comprises a few functions. The two main ones generate an ODEProblem and the EnsembleProblem; in addition there are some 'helper' functions: one that updates an existing problem, and a callback that sets 'species' with low abundance to zero, i.e. x_i = 0 when x_i < \varepsilon.

The code

The function to generate the ODEProblem:

function define_odeproblem(S, a; x0=rand(S), k=1.0, tspan=(0.0, 1_000.0))
    @variables t
    @variables (x(t))[1:S]
    @parameters A[1:S,1:S]
    D = Differential(t)

    eqns = [D(x[i]) ~ x[i] * (1.0 - sum([A[i,j]*x[j] for j in 1:S])) for i in 1:S]
    @named odesys = ODESystem(eqns, t)

    params = vec([A[i,j] => a[i,j] for i in 1:S, j in 1:S])
    u0 = [x[i] => x0[i] for i in 1:S]
    odeprob = ODEProblem(complete(odesys), u0, tspan, params)
    return odeprob 
end
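
For reference, a minimal smoke test of this function in isolation (hypothetical values, not part of the original setup) could look like:

#~ Build a small problem and solve it once
using ModelingToolkit, DifferentialEquations

S = 4
a = fill(0.1, S, S)
a[CartesianIndex.(1:S, 1:S)] .= 1.0  # unit self-interaction on the diagonal
prob = define_odeproblem(S, a; tspan=(0.0, 100.0))
sol = solve(prob, Tsit5(); save_everystep=false)
sol.u[end]  # only the initial and final states are stored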

The function to generate the EnsembleProblem:

function define_ensembleproblem(S, μ, σ; nseeds=8, tspan=(0.0, 1_000.0))
    #/ Initialize interactions
    #~ (does not really matter how, as they'll be changed in set_interactions(..) anyway)
    function generate_interactions(S, μ, σ; seed=42)
        a = μ/S .+ σ*randn(Random.Xoshiro(1234*seed), S, S)/sqrt(S)
        a[CartesianIndex.(1:S, 1:S)] .= 1.0  # unit self-interaction on the diagonal
        return a
    end
    #~ distinct seeds, so each trajectory gets its own matrix
    amatrices = [generate_interactions(S, μ, σ; seed=i) for i in 1:nseeds]

    #/ Define ODEProblem
    prob = define_odeproblem(S, amatrices[begin]; tspan=tspan)

    function set_interactions(prob, k, nrepeats)
        #~ Set interactions randomly and update existing ODEProblem
        @unpack A = prob.f.sys
        Amap = vec([A[i,j] => amatrices[k][i,j] for i in 1:S, j in 1:S])
        return update_prob(prob, Amap)
    end

    #/ Define EnsembleProblem
    #~ Save only final state
    output_func(sol, i) = (last(sol), false)
    eprob = EnsembleProblem(prob, output_func=output_func, prob_func=set_interactions)
    return eprob, amatrices
end

and the 'helper' functions:

"Update ODEProblem"
function update_prob(prob, pmap)
    #~ Updates given ODEProblem with the new parameters given in pmap
    #~ Old parameters (not in pmap) remain the same
    params = ModelingToolkit.parameters(prob.f.sys)
    pdict = Dict(params[n] => prob.p[n] for n in 1:length(prob.p))
    p = ModelingToolkit.varmap_to_vars(pmap, params, defaults=pdict)
    updated_prob = remake(prob; p=p, u0=rand()*prob.u0)
    return updated_prob
end

"Callback to put extinct species to 0.0"
function cb_extinction(threshold::Float64)
    #/ Create callback to let species with low abundances go extinct
    function condition(u, t, integrator)
        any(u .< threshold)
    end
    function affect!(integrator)
        integrator.u[integrator.u .< threshold] .= 0.0
    end
    end
    return DiscreteCallback(condition, affect!)
end
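
A hypothetical standalone usage, passing the callback to solve (assuming a prob built with define_odeproblem above):

cb = cb_extinction(1e-6)  # species with abundance below 1e-6 are set to zero
sol = solve(prob, Tsit5(); callback=cb, save_everystep=false)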

The function that updates the problem is written such that, in principle, not all parameters need to be given for the problem to be updated. One can, for example, give only A_{11} and update the problem in that way. If there is a better way to write this, please let me know.
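
For example, a hypothetical call updating only A_{11} (reusing the @unpack pattern from the code above) would be:

@unpack A = prob.f.sys
#~ Only A[1,1] is given; all other parameters keep their current values
newprob = update_prob(prob, [A[1,1] => 0.5])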

The problem

I am encountering issues when increasing the number of species and/or the number of seeds. The problem is not the runtime, but the many memory allocations whose origin I do not know. Intuitively, I would expect some additional temporary memory to be used; however, I have observed crashes once no more RAM is available.

Consider the following example (I omitted all include statements for brevity):

julia> S = 32; μ = 4.0; σ = 0.5;

julia> eprob, amatrices = define_ensembleproblem(S, ΞΌ, Οƒ; nseeds=32);

julia> esol = solve_ensembleproblem(eprob, length(amatrices))

julia> @benchmark solve_ensembleproblem(eprob, length(amatrices))
BenchmarkTools.Trial: 10 samples with 1 evaluation.
 Range (min … max):  516.055 ms … 989.800 ms  ┊ GC (min … max): 0.00% … 46.69%
 Time  (median):     539.820 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   582.574 ms ± 143.338 ms  ┊ GC (mean ± σ):  7.93% ± 14.77%

    ▂█
  ▅▁██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁
  516 ms           Histogram: frequency by time          990 ms <

 Memory estimate: 528.11 MiB, allocs estimate: 7988942.

julia> eprob, amatrices = define_ensembleproblem(S, μ, σ; nseeds=64);

julia> esol = solve_ensembleproblem(eprob, length(amatrices));

julia> @benchmark solve_ensembleproblem(eprob, length(amatrices))
BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  1.096 s …    1.626 s  ┊ GC (min … max): 0.00% … 30.79%
 Time  (median):     1.130 s               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.222 s ± 226.789 ms  ┊ GC (mean ± σ):  8.20% ± 13.77%

  █ ███                                                    █
  █▁███▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.1 s          Histogram: frequency by time         1.63 s <

 Memory estimate: 1.03 GiB, allocs estimate: 15977622.

First, the total memory used is, in my opinion, really large. I suspect that the ODEProblem defined using ModelingToolkit is not in-place (i.e., it is allocating)? Could this be the underlying issue, and if so, how could it be fixed?
This would kind of make sense, as the memory estimate doubles when the number of seeds doubles. However, I explicitly tell the solver not to store anything but the final state, so I would expect the memory to grow only slightly, to accommodate the size of the solution (which is double the size).

Are there further ways I can improve the (memory) efficiency of this code? I am looking to model systems where both S and nseeds are on the order of 10^2 to 10^3, so at the very least I need a way to run these large models without crashing.

No, it generates non-allocating in-place code by default.

Factor the symbolic indexing parts out of the ensemble via setu.

This allocates, use a generator.

The other big source of memory usage is dense output. How are you calling solve?


Sorry, I apparently did not include this call here; it is:

esol = solve(
        eprob,
        Tsit5(), EnsembleThreads(),
        callback=cb, trajectories=nseeds, save_everystep=false
    )

Factor the symbolic indexing parts out of the ensemble via setu.

Is it just me, or is this behavior basically undocumented? I have found only this link, but no documented example of how to use it. Looking through the source code, I find in SymbolicIndexingInterface/A1VUA/parameter_indexing.jl:

"""
    setp(indp, sym)

Return a function that takes an index provider and a value, and sets the parameter `sym`
to that value. Note that `sym` can be an index, a symbolic variable, or an array/tuple of
the aforementioned.

Requires that the value provider implement [`parameter_values`](@ref) and the returned
collection be a mutable reference to the parameter object. In case `parameter_values`
cannot return such a mutable reference, or additional actions need to be performed when
updating parameters, [`set_parameter!`](@ref) must be implemented.
"""

I have no idea what this means. Do I need to define a function parameter_values and give this to setp?

This allocates, use a generator

I will look into this; is it just using any(x -> x < threshold, u) instead? Interestingly, even if I omit the callback, I see no change in memory usage.

I think I figured out how to use setu/setp from some of Catalyst's documentation, and have updated the code; the relevant updated parts are:

#/ Define ODEProblem
prob = define_odeproblem(S, amatrices[begin]; tspan=tspan)
#/ Define parameter setter
@unpack A = prob.f.sys
Asymb = vec([A[i,j] for i in 1:S, j in 1:S])
interactionsetter = ModelingToolkit.setp(prob, Asymb)

function set_interactions(prob, k, nrepeats)
    #~ Set interactions, which effectively updates the existing ODEProblem in place
    interactionsetter(prob, amatrices[k])
    return prob
end

#/ Define EnsembleProblem
#~ Save only final state
output_func(sol, i) = (last(sol), false)
eprob = EnsembleProblem(prob, prob_func=set_interactions, output_func=output_func)

Still, a lot of memory is being used, especially compared to solving a single ODE, e.g. by running:

julia> prob = define_odeproblem(S, a)
julia> @btime sol = solve(prob, Tsit5(), save_everystep=false)
6.872 ms (212 allocations: 160.42 KiB)

Could it be that a lot of the memory is allocated by copying the problem onto the different cores? Or is it solely the way I set the parameters that allocates? Where else could the allocations be occurring? Running a profiler also indicates some calls to deepcopy(), which suggests that some things are being copied.
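
One way I could try to locate the allocations (assuming Julia 1.8 or later) is the built-in allocation profiler, something like:

using Profile
#~ Sample a fraction of all allocations; sample_rate trades detail for overhead
Profile.Allocs.@profile sample_rate=0.1 solve_ensembleproblem(eprob, length(amatrices))
results = Profile.Allocs.fetch()  # inspect e.g. with PProf.jl: PProf.Allocs.pprof(results)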

Full code here
using ModelingToolkit
using DifferentialEquations
using Random

function define_odeproblem(S, a; x0=rand(S), k=1.0, tspan=(0.0, 1_000.0))
    @variables t
    @variables (x(t))[1:S]
    @parameters A[1:S,1:S]
    D = Differential(t)

    eqns = [D(x[i]) ~ x[i] * (1.0 - sum([A[i,j]*x[j] for j in 1:S])) for i in 1:S]
    @named odesys = ODESystem(eqns, t)

    params = vec([A[i,j] => a[i,j] for i in 1:S, j in 1:S])
    u0 = [x[i] => x0[i] for i in 1:S]
    odeprob = ODEProblem(complete(odesys), u0, tspan, params)
    return odeprob 
end

"Callback to put extinct species to 0.0"
function cb_extinction(threshold::Float64)
    #/ Create callback to let species with low abundances go extinct
    function condition(u, t, integrator)
        any(x -> (x < threshold), u)
    end
    function affect!(integrator)
        for i in eachindex(integrator.u)
            if integrator.u[i] < threshold
                integrator.u[i] = 0.0
            end
        end
    end
    return DiscreteCallback(condition, affect!)
end

"Run system for nseeds times"
function define_ensembleproblem(S, μ, σ; nseeds=8, tspan=(0.0, 10_000.0))
    #/ Initialize interactions
    #~ (does not really matter how, as they'll be changed in set_interactions(..) anyway)
    function generate_interactions(S, μ, σ; seed=42)
        a = μ/S .+ σ*randn(Random.Xoshiro(1234*seed), S, S)/sqrt(S)
        for i in 1:S
            a[i,i] = 1.0  # unit self-interaction on the diagonal
        end
        return a
    end
    amatrices = [generate_interactions(S, μ, σ; seed=i) for i in 1:nseeds]

    #/ Define ODEProblem
    prob = define_odeproblem(S, amatrices[begin]; tspan=tspan)
    #/ Define parameter setter
    @unpack A = prob.f.sys
    Asymb = vec([A[i,j] for i in 1:S, j in 1:S])
    interactionsetter = ModelingToolkit.setp(prob, Asymb)

    function set_interactions(prob, k, nrepeats)
        #~ Set interactions, which effectively updates the existing ODEProblem in place
        interactionsetter(prob, amatrices[k])
        return prob
    end

    #/ Define EnsembleProblem
    #~ Save only final state
    output_func(sol, i) = (last(sol), false)
    eprob = EnsembleProblem(prob, prob_func=set_interactions, output_func=output_func)
    return eprob, amatrices
end
    
function solve_ensembleproblem(eprob, nseeds; threshold=1e-6, cb=cb_extinction(threshold))
    #/ Solve
    esol = solve(
        eprob,
        AutoTsit5(Rosenbrock23()), EnsembleThreads(), callback=cb,
        trajectories=nseeds, save_everystep=false
    )
    return esol
end

@cryptic.ax we should improve the doc example based on this feedback.

OK, so now I am pretty sure that the "culprit" is the deepcopy that is called when safetycopy=true, which I currently need as my prob_func is not thread safe. However, as my ODEProblems are quite large when the number of species is large (mostly because the number of interactions scales with S^2), this creates quite a large memory overhead, to the point that I sometimes run out of memory when S and/or nseeds are large enough.

Is it possible to write a prob_func that changes the problem's parameters while being thread safe? In that case, I could set safetycopy=false and be done. I know my current implementation is not thread safe, as my solutions contain duplicates, but I am currently not sure how to make it thread safe, or whether that is possible at all.

My current implementation of prob_func is:

#/ Define state- and parameter-setters using setu/setp
@unpack x, A = prob.f.sys
xsymb = vec([x[i] for i in 1:S])
Asymb = vec([A[i,j] for i in 1:S, j in 1:S])
statesetter = ModelingToolkit.setu(prob, xsymb)
interactionsetter = ModelingToolkit.setp(prob, Asymb)

function set_interactions(prob, k, nrepeats)
    #/ Update prob by setting new interactions and initial conditions
    #~ Get SxS interaction matrix and new init state
    _a = generate_interactions(S, μ, σ; seed=k)
    _u = rand(Random.Xoshiro(k), S)
    #~ Set state and interactions
    statesetter(prob, _u)
    interactionsetter(prob, _a)
    return prob
end

Is it perhaps possible to deepcopy the original ODEProblem to all threads and then use the above definition for prob_func? The idea is that you just need separate copies of the (original) prob; each thread can then change the parameters of its own copy in place, since it makes serial calls to prob_func whenever it has finished its last trajectory. Or is this more in the realm of EnsembleDistributed instead of EnsembleThreads?

Yes, it's a safety thing, but you definitely want to get rid of that assumption when you need performance.

Yup, that's a good way to handle it. One problem per thread, with each thread modifying only its own, so each thread is non-allocating. Though not necessarily per thread, because threads can migrate; it would be per Task, but same idea.

EnsembleDistributed could give you better performance if you're allocating a lot, since each process has its own GC. So for very large ensembles and large core counts it can perform better than EnsembleThreads if your code does not avoid all allocations.
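
For completeness, a sketch of what the distributed call would look like (the worker count is arbitrary, and each worker has to load the solver code):

using Distributed
addprocs(4)  # each worker process gets its own GC
@everywhere using DifferentialEquations

esol = solve(eprob, Tsit5(), EnsembleDistributed();
             trajectories=nseeds, save_everystep=false)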

That is good to know. I am not quite sure how to do it yet, as I think I cannot simply use threadid() to index, say, a shared array with deepcopies of the original problem. I will look into tasks and how to create an ODEProblem for each of the tasks that can be altered in-place.

Ideally I don't want to allocate anything, as I just need the final state. The callback should also be non-allocating. So in this case there should be no need for safetycopy=true other than to avoid race conditions, and the overhead of EnsembleDistributed() (given that I am using a single machine with many cores) is probably not worth it.

I have a potential solution that appears to work. The principle is relatively simple:

  1. Create a dictionary (or any 'map') that maps the task ID to an ODEProblem ascribed to that specific task. The ODEProblem is deepcopy'd once per task in the beginning.
  2. In the prob_func, get the task ID using current_task(); deepcopy the original ODEProblem if no copy exists for that task yet, otherwise mutate the existing one.

The code looks like this, with a lock in place to prevent multiple tasks/threads from mutating the dictionary at the same time:

#/ Create a dictionary to store task-local problems
tproblems = Dict{Task, ODEProblem}()
tlock = Threads.ReentrantLock()

function set_interactions(prob, k, nrepeats)
    #~ Update prob by setting new interactions and initial conditions
    tid = current_task()
    #~ Dict is not thread safe, so both the lookup and the (possible) insert
    #~ must happen while holding the lock
    lprob = lock(tlock) do
        get!(() -> deepcopy(prob), tproblems, tid)
    end
    interactionsetter(lprob, amatrices[k])
    statesetter(lprob, rand(S))
    return lprob
end

with the statesetter and interactionsetter as above. This seems to work, and indeed uses only as much memory as the initial set of deepcopy calls, but nothing more.

If this is not the way to go, I'd be happy to learn how I should implement this otherwise.

Not saying you are doing anything wrong, but a simpler way of achieving this would be to use TaskLocalValue from OhMyThreads.jl :slight_smile:


Indeed, it seems to work the same way, and it even gives me slightly better performance! The code is now something like this:

#/ Define a TaskLocalValue that deepcopy's the problem to the task if it does not exist
#~ see: https://juliafolds2.github.io/OhMyThreads.jl/stable/literate/tls/tls/#TLV
using OhMyThreads: TaskLocalValue

tlv_prob = TaskLocalValue{ODEProblem}(() -> deepcopy(prob))

function set_interactions(prob, k, nrepeats)
    #~ Get local problem
    localprob = tlv_prob[]
    #~ Update prob by setting new interactions and initial conditions
    interactionsetter(localprob, amatrices[k])
    statesetter(localprob, rand(S))
    return localprob
end