Mutable structs in DifferentialEquations.jl parameters to cache data


I have a model implemented in DifferentialEquations.jl that saves a lot of states. Some are just temporary, used for calculations or history; some are required for saving. Only a small subset of the states is the minimum required to actually solve the ODE. I organize all of this data in a mutable struct that serves as a container for all of the submodels and states. I store this mutable struct in my ODE parameters, for example as p.sys, so that I can access it both in the ODE function and through the integrator interface in callbacks.

Each time step, I take the integrated state results and pack them where they belong in p.sys. All calculations are performed on p.sys in place. At the end of each time step, I pack only the minimum states I need into the ode func derivative output and save the states I care about with a saving callback. Everything seems to work okay.

My question is very generic. Is this the right way to do this? Are there any performance downsides to this method of mutating fields of the ODE parameters every time step? Does this methodology raise any red flags?

I think it’s a common approach, which should be fine. Often you could even use a plain NamedTuple or an immutable struct if all your cached variables are arrays or references which can be changed either way.

The other way to store cache variables is with a struct that implements function (s::YourStruct)(du, u, p, t), which then serves as the RHS of your ODE; see the DifferentialEquations.jl tutorial “An Implicit/Explicit CUDA-Accelerated Solver for the 2D Beeler-Reuter Model”.
(I personally prefer the approach you use. Not 100% sure if there are drawbacks.)
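For reference, here is a minimal sketch of the callable-struct pattern. The names `LinearDecay`, `rate`, and `buf` are hypothetical. The struct owns its scratch buffer and is called with the same `(du, u, p, t)` signature that an `ODEProblem` expects, so it can be passed directly as the RHS. Note that a `Vector{Float64}` buffer like this one still cannot hold dual numbers during autodiff.

```julia
# A callable struct ("functor") that owns its scratch buffer and acts as the ODE RHS.
struct LinearDecay{T}
    rate::Float64      # hypothetical model constant
    buf::Vector{T}     # scratch space reused on every call
end

# Convenience constructor: allocate a zeroed buffer of length n.
LinearDecay(rate::Float64, n::Integer) = LinearDecay(rate, zeros(n))

function (f::LinearDecay)(du, u, p, t)
    f.buf .= f.rate .* u   # intermediate result stored in the struct's own buffer
    du .= .-f.buf          # du/dt = -rate * u
    return nothing
end

f = LinearDecay(2.0, 2)
du = zeros(2)
f(du, [1.0, 3.0], nothing, 0.0)   # same call an ODE solver would make
du
```

With OrdinaryDiffEq loaded, `ODEProblem(f, u0, tspan)` would use `f` as the RHS directly.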

Regarding the mutation of the ODE parameters, keep in mind that for some integrators the function is not evaluated at monotonically increasing times (for some models, that matters).
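A step rejection in an adaptive method is enough to make the evaluation times go backwards. The sketch below uses a deliberately tiny hand-rolled Euler/Heun stepper; it is an illustrative stand-in, not how OrdinaryDiffEq implements step-size control. The first trial step evaluates the RHS at `t = 1.0`. That step is then rejected, and the stepper retries from `t = 0.0`.

```julia
# Tiny adaptive Euler/Heun integrator (illustrative only, not DifferentialEquations.jl).
function adaptive_heun(f, u0, tspan; dt=1.0, tol=1e-3)
    t, tend = tspan
    u = u0
    while t < tend
        dt = min(dt, tend - t)
        k1 = f(u, t)
        k2 = f(u + dt * k1, t + dt)         # trial evaluation at the end of the step
        err = abs(dt / 2 * (k2 - k1))       # |Heun - Euler| error estimate
        if err <= tol
            u += dt / 2 * (k1 + k2)         # accept: time advances
            t += dt
        end
        dt *= err <= tol ? 1.5 : 0.5        # grow after accepting, shrink after rejecting
    end
    return u
end

ts = Float64[]                              # every time at which the RHS was called
rhs(u, t) = (push!(ts, t); -10.0 * u)
adaptive_heun(rhs, 1.0, (0.0, 1.0))
issorted(ts)   # false: the rejected trial at t = 1.0 is followed by a retry from t = 0.0
```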


Thanks @SteffenPL ! That Beeler-Reuter example is a really interesting idea. In my gut, it felt weird to mutate the parameters, hence the post. Glad to see I wasn’t too far off in my approach, but the function-struct approach syntactically feels a little better. I know you said you weren’t 100% sure if there were any drawbacks, but any reason in particular you prefer the mutate-struct-in-parameters approach? Would the function-struct approach solve your last point about the evaluation not being monotone with respect to time?

Shortly after hitting enter on this, I think I realized one drawback of the function (::mystruct)(du,u,p,t) approach. I save many of the fields of mystruct at each time step using a saving callback. I’m only able to do that because sys lives in p, which is accessible through the integrator interface. In the function-struct approach, I believe the fields would not be accessible from the integrator interface, and thus couldn’t be accessed by a saving callback. Perhaps there are other ways to save the fields at each time step besides the saving callback, for example SimulationLogs.jl (jonniedie/SimulationLogs.jl on GitHub), which provides signal logging and scoping for DifferentialEquations.jl simulations.

It’s there in the integrator, but it’s been wrapped into oblivion, first in a FunctionWrapper (which itself nests it inside a SciMLBase.Void inside a RefValue), then in a FunctionWrappersWrapper, and then in an ODEFunction. I couldn’t find an officially supported way to unwrap all of this, but with the current versions of the packages, the following works: integrator.f.f.fw[1].obj[].f.

using OrdinaryDiffEq

struct ODE
    ts::Vector{Float64}   # field name reconstructed; collects the callback times
    ODE() = new(Float64[])
end

(::ODE)(du, u, p, t) = (du .= -u)

f = ODE()
prob = ODEProblem(f, [1.0], (0.0, 1.0))
callback = let
    condition(u, t, integrator) = true
    function affect!(integrator)
        # unwrap the ODE struct again (internal, version-dependent path)
        push!(integrator.f.f.fw[1].obj[].f.ts, integrator.t)
        u_modified!(integrator, false)
    end
    DiscreteCallback(condition, affect!; save_positions=(true, false))
end

sol = solve(prob, Tsit5(); callback)
f.ts  # in the original run: a 4-element Vector{Float64}

But given this level of inception, it’s probably wiser to (ab)use p for this. Just make sure you only modify it in a callback and not from the ODE function itself.


Oof, okay yea I think the p interface is certainly better. However, I am definitely calculating and mutating all the values of p.sys in the ODE func. So far all of the results have been very reasonable. What issues should I be looking out for by modifying p in the ODE func?

I don’t know exactly what your p.sys is all about (maybe you could show a simplified example?), but the general problem is that the ODE function is called an unknown number of times per time step, with a lot of different values for u and t. For one there are the Runge-Kutta stages of the algorithm you’re using, possibly inside an iterative nonlinear solver if you’re using an implicit method; the function is sometimes called with dual numbers in u to calculate derivatives; and then there’s the adaptive time stepping, which means that the result of all these calls might be discarded and repeated with a smaller time step. I don’t know what you could do with p.sys from all these calling contexts that would make sense or be useful.
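To see concretely why a plain `Float64` buffer breaks when the function is called with dual numbers, here is a minimal sketch. It assumes ForwardDiff.jl is installed, and `buf` is hypothetical. Writing a dual number into a `Vector{Float64}` fails with the same kind of `MethodError` that appears in the Rosenbrock23 error message in this thread.

```julia
using ForwardDiff

buf = zeros(3)                     # plain Float64 scratch buffer, as in a naive cache
x = ForwardDiff.Dual(1.0, 1.0)     # the kind of value an autodiff Jacobian passes in

err = try
    buf[1] = x                     # needs a conversion from Dual to Float64
    nothing
catch e
    e
end
err isa MethodError   # true: "no method matching Float64(::ForwardDiff.Dual...)"
```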

In short, anything you want to do at the actual values of u, t that comprise the ODE solution must be done in a callback, not within the ODE func.

Things like preallocated buffers to speed up the ODE function itself can of course be mutated from within the function. If that’s your use case, make sure to check out PreallocationTools.jl to ensure dual number support.
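Here is a minimal sketch of the PreallocationTools.jl fix, assuming both PreallocationTools.jl and ForwardDiff.jl are installed. The names `scratch` and `g!` are hypothetical. `get_tmp` returns a buffer whose element type matches its second argument, so the same in-place function works both for plain calls and inside `ForwardDiff.jacobian`.

```julia
using ForwardDiff, PreallocationTools

# Scratch array that can hand out Float64 or dual-number storage as needed.
const scratch = DiffCache(zeros(3))

function g!(y, x)
    tmp = get_tmp(scratch, x)   # Float64 buffer for Float64 x, dual buffer for dual x
    tmp .= 2 .* x
    y .= tmp .* x               # y = 2 .* x.^2
    return nothing
end

x = [1.0, 2.0, 3.0]
y = zeros(3)
J = ForwardDiff.jacobian(g!, y, x)   # diagonal with entries 4 .* x
```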


Yea that makes sense. I think I might be ok currently in my use case, but I need to be more mindful of this going forward.

p.sys contains basically 3 kinds of information in my model.

  • constant parameter data (never mutated)
  • interesting data to save (not actually used in calcs, just placed here for access in the saving callback)
  • temporary matrices/arrays for calculations

I believe my results so far have been OK because all the fields that I mutate in p.sys to support calculations are directly or indirectly dependent on u at each time step/stage of RK4. The entire system gets re-derived from u, so the mutated fields just keep getting overwritten. If I had fields that were incremented each time step, for example, I could see how those results would be compromised.
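The difference between overwritten and accumulated fields can be illustrated with a small sketch. It uses a hand-rolled explicit midpoint method (two RHS calls per step) as a stand-in for a real solver, and `ModelSys`, `X`, and `ncalls` are hypothetical names. The field derived from `u` is always consistent with the last call. The counter, however, ends up counting RHS evaluations instead of time steps.

```julia
# Hypothetical container: one field re-derived from u, one field that accumulates.
mutable struct ModelSys
    X::Vector{Float64}   # overwritten from u on every call: safe
    ncalls::Int          # incremented on every call: counts RHS calls, not steps
end

function rhs!(du, u, sys, t)
    sys.X .= u .^ 2
    sys.ncalls += 1
    du .= .-u
    return nothing
end

# Explicit midpoint method: two RHS evaluations per step.
function midpoint_solve!(u, sys, dt, nsteps)
    k = similar(u)
    t = 0.0
    for _ in 1:nsteps
        rhs!(k, u, sys, t)
        umid = u .+ (dt / 2) .* k
        rhs!(k, umid, sys, t + dt / 2)
        u .+= dt .* k
        t += dt
    end
    return u
end

sys = ModelSys(zeros(2), 0)
midpoint_solve!([1.0, 2.0], sys, 0.1, 10)
sys.ncalls   # 20 after 10 steps
```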

Admittedly, I am currently forcing fixed step Tsit5 as the solver but I’d like to leave the model open to variable step in the future. For fixed step, I could run the model in a periodic callback, but I’m not sure how to set up a callback to run the model every timestep for a variable step solver.

Edit: FunctionCallingCallback or SavingCallback should do this for me. I’m not sure running the model in a callback works for me, however. My ODE function basically looks like this:

function odefunc!(du,u,p,t)
    p.sys.u  = u
    model!(p.sys) # all the calculations for du
    du .= p.sys.du # copy into du; `du = p.sys.du` would only rebind the local name
end

If I run the model in the callback rather than the ode func, then the model is no longer part of the solver, and all I’m integrating is a constant du calculated in a callback. i.e.

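function odefunc!(du,u,p,t)
    du .= p.sys.du # du stays frozen at whatever the last callback computed
end

function affect!(integrator)
    # run the model at the accepted state (body reconstructed from the description above)
    integrator.p.sys.u = integrator.u
    model!(integrator.p.sys)
end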

Or does the callback get called each stage of the solver?
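One way to check this empirically is sketched below. It assumes OrdinaryDiffEq.jl and DiffEqCallbacks.jl are installed, and `decay!`, `nrhs`, and `tcb` are hypothetical. The sketch counts RHS calls and records the times at which a `FunctionCallingCallback` fires. Without a `funcat` argument, the callback runs after accepted steps, not after every stage.

```julia
using OrdinaryDiffEq, DiffEqCallbacks

# nrhs counts RHS evaluations; tcb records the times at which the callback fires.
p = (; nrhs = Ref(0), tcb = Float64[])

function decay!(du, u, p, t)
    p.nrhs[] += 1
    du .= .-u
    return nothing
end

# No `funcat` given, so the function is called at every accepted step.
cb = FunctionCallingCallback((u, t, integrator) -> push!(integrator.p.tcb, t))

prob = ODEProblem(decay!, [1.0], (0.0, 1.0), p)
sol = solve(prob, Tsit5(); callback=cb)

(length(p.tcb), p.nrhs[])   # far fewer callback calls than RHS calls
issorted(p.tcb)             # accepted step times only move forward
```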

Yeah, this is the preallocated buffer use case I mentioned. That’s fine, of course; I didn’t mean to suggest you outsource the model calculations to callbacks. Just note that you might want to use DiffCache from PreallocationTools.jl instead of plain arrays if you ever want to do anything that requires autodiff (implicit algorithms, sensitivity analysis, et cetera). Also be aware that this pattern is not thread-safe, so you should avoid parallelized algorithms (search the “ODE Solvers” page of the DifferentialEquations.jl documentation for “thread” to see what to avoid).

The second kind of data in your list, “interesting data to save”, is what I’m saying should probably be computed in the callback, not in the ODE function. When entering the SavingCallback, you have no idea what the values of u and t were the last time the ODE function was called, and they likely don’t equal any step of the ODE solution (although they could if you’re using an algorithm with the FSAL property). In addition, doing these calculations/saves for every Runge-Kutta stage is wasteful. Is there a reason you’ve been doing this in the ODE function rather than the callback?
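The stale-value problem can be made concrete with a hand-rolled classic RK4 step, which is not FSAL. This is an illustrative stand-in for a real solver, and `cache`, `X`, and `rk4_step` are hypothetical. The last RHS call of the step is made at the trial point `u + h*k3`. The cached `X` therefore does not match `X` evaluated at the accepted state `u1`.

```julia
# RHS with a side effect: caches the "state function" X = u.^2 of its input.
cache = (; X = zeros(1))

function rhs!(du, u, t)
    cache.X .= u .^ 2      # X at whatever (u, t) the stepper asked about
    du .= .-u
    return nothing
end

function rk4_step(u, t, h)
    k1 = similar(u); k2 = similar(u); k3 = similar(u); k4 = similar(u)
    rhs!(k1, u, t)
    rhs!(k2, u .+ (h / 2) .* k1, t + h / 2)
    rhs!(k3, u .+ (h / 2) .* k2, t + h / 2)
    rhs!(k4, u .+ h .* k3, t + h)   # last call: at the trial point, not at u1
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

u1 = rk4_step([1.0], 0.0, 0.5)
stale = copy(cache.X)   # what a saving callback would read from the cache
fresh = u1 .^ 2         # X recomputed at the accepted state
stale ≈ fresh           # false
```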

This has been very helpful and interesting. I’ve spent some time reading through PreallocationTools.jl and have come up with an MWE (though perhaps not that minimal) that follows pretty closely what my model is trying to do. The model runs allocation-free, but if I try to use Rosenbrock23() I get an error saying that the function is not compatible with automatic differentiation. I’m still working through the model and trying to learn how to use DiffCaches appropriately, but I’ll leave the code here in case the mistake is obvious.

I am dependent on both parallelization and ForwardDiff at some point in this project, and it sounds like this will be the right way to set up the model.

using DifferentialEquations, PreallocationTools, StaticArrays

struct Model
    #many other fields
    uindex::SVector{3,Int}
    X::DiffCache{MMatrix{3, 3, Float64, 9}, Vector{Float64}} # Maybe I can't define this so explicitly, in case it's dual type?
    dummyMat::MMatrix{3,3,Float64,9}  # just used to define type of X
    function Model(uindex,X)
        dummyMat = @MMatrix zeros(3,3)
        return new(uindex,X,dummyMat)
    end
end

struct MySystem
    #many other fields
    models::Vector{Model}
end

function run_models!(sys,u,du)
    for model in sys.models
        x = SVector{3,Float64}(u[model.uindex])
        X = get_tmp(model.X,model.dummyMat)
        X .= 1e-2*x*x' #arbitrary function for this example

        du[model.uindex] .= X * x
    end
end

function odefunc!(du,u,p,t)
    run_models!(p.sys,u,du)
end

function simulate(original_sys,u; tspan=(0,10), solver=Tsit5()) #solver = Rosenbrock23() not working
    #deep copy to support restarting the sim with the same sys
    #this would go in the ensemble update function for parallel sims, maybe that solves parallelization?
    sys = deepcopy(original_sys)
    p = (sys=sys,)

    # saving X: X is required to solve for u, but isn't part of u, so it doesn't get saved
    # saved in Model so it's only calculated one time, then placed in the callback
    save_values = SavedValues(Float64, Tuple{SMatrix{3,3,Float64,9},SMatrix{3,3,Float64,9}})
    function save_function(u, t, integrator)
        Tuple([get_tmp(integrator.p.sys.models[i].X,integrator.p.sys.models[i].dummyMat) for i in eachindex(integrator.p.sys.models)])
    end
    save_cb = SavingCallback(save_function, save_values)

    prob = ODEProblem(odefunc!, u, tspan, p)
    sol = solve(prob, solver, callback=save_cb)
    return sol, save_values
end

model1 = Model(SA[1,2,3], DiffCache(@MMatrix zeros(3,3)))
model2 = Model(SA[4,5,6], DiffCache(@MMatrix zeros(3,3)))

u = randn(6) #do I need to make these duals? or don't ForwardDiff or Rosenbrock convert for me?
du = randn(6)

sys = MySystem([model1,model2])

sol = simulate(sys,u)
julia> @btime odefunc!($du,$u,$p,$1.0)
  14.100 ns (0 allocations: 0 bytes)
julia> sol = simulate(sys,u,solver = Rosenbrock23())
ERROR: First call to automatic differentiation for the Jacobian
failed. This means that the user `f` function is not compatible
with automatic differentiation. Methods to fix this include:

1. Turn off automatic differentiation (e.g. Rosenbrock23() becomes
   Rosenbrock23(autodiff=false)). More details can befound at
2. Improving the compatibility of `f` with ForwardDiff.jl automatic
   differentiation (using tools like PreallocationTools.jl). More details
   can be found at
3. Defining analytical Jacobians. More details can be
   found at

Note: turning off automatic differentiation tends to have a very minimal
performance impact (for this use case, because it's forward mode for a
square Jacobian. This is different from optimization gradient scenarios).
However, one should be careful as some methods are more sensitive to
accurate gradients than others. Specifically, Rodas methods like `Rodas4`
and `Rodas5P` require accurate Jacobians in order to have good convergence,
while many other methods like BDF (`QNDF`, `FBDF`), SDIRK (`KenCarp4`),
and Rosenbrock-W (`Rosenbrock23`) do not. Thus if using an algorithm which
is sensitive to autodiff and solving at a low tolerance, please change the
algorithm as well.

MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50

  [1] jacobian!(J::Matrix{Float64}

Thanks for showing an example! I’ll mention a couple of issues, and then I’ll show a rewrite below.

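Regarding the hand-written concrete types (the DiffCache{MMatrix{3, 3, Float64, 9}, Vector{Float64}} field, dummyMat, the SavedValues element type): you’re doing the compiler’s work here, trying to figure out the exact types of everything. It is better to use parametric types and let the compiler do the type inference.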
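Here is a small sketch of the difference between the two approaches, using hypothetical `RigidCache`/`FlexCache` types. A hand-written `Vector{Float64}` field silently converts whatever you pass in, and that conversion is exactly what fails for dual numbers. A parametric field keeps the argument's own concrete type, so field access stays type-stable.

```julia
# Field type written out by hand: always Vector{Float64}.
struct RigidCache
    buf::Vector{Float64}
end

# Parametric field type: the concrete type is taken from the argument.
struct FlexCache{B<:AbstractVector}
    buf::B
end

a = zeros(Float32, 3)
r = RigidCache(a)    # converts: r.buf is a new Vector{Float64}, not `a`
fc = FlexCache(a)    # FlexCache{Vector{Float32}}: stores `a` itself, still concretely typed

(r.buf === a, fc.buf === a)   # (false, true)
```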

Also, there’s rarely, if ever, a need to use preallocated buffers if you’re using StaticArrays. Better to work out of place in that case. To explain the principles I’ll show you my in-place version using regular arrays and preallocation first, but an out-of-place StaticArrays version would probably be a better way to solve this particular toy example.

The lines x = SVector{3,Float64}(u[model.uindex]) and X = get_tmp(model.X,model.dummyMat) are where your use of PreallocationTools goes wrong and leads to Rosenbrock23() failing. For one, you shouldn’t convert the state vector to a fixed Float64 type: if the function was called with a dual vector, it should remain a dual vector. Second, you’re using a fixed dummy array to get your cache, which means you will always get a regular cache and never a dual cache. The point is to use u or x as the second argument (doesn’t matter which), which will give back a regular cache when u is regular and a dual cache when u is dual.
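A minimal sketch of that difference is below, assuming PreallocationTools.jl and ForwardDiff.jl are installed. The names `cache`, `dummy`, and `xd` are hypothetical. The element type that `get_tmp` returns follows its second argument. A fixed `Float64` dummy therefore always yields a `Float64` buffer, even when the solver is passing dual numbers.

```julia
using ForwardDiff, PreallocationTools

cache = DiffCache(zeros(3))
dummy = zeros(3)                               # fixed Float64 "type template"
xd = ForwardDiff.Dual.([1.0, 2.0, 3.0], 1.0)   # a dual-valued state, as during autodiff

eltype(get_tmp(cache, dummy))   # Float64, whatever the solver passed in
eltype(get_tmp(cache, xd))      # the Dual type of xd: matches the incoming state
```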

Regarding the deepcopy in the ensemble update function: I’m not experienced with ensemble problems, but yes, as long as the different systems in the ensemble don’t share p.sys you should be fine. The thread-safety issue I was talking about concerned particular solvers that make parallel calls to the ODE function for a single system, like KuttaPRK2p5.
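One way to make sure trajectories never share `p.sys` is to give each one its own deep copy in `prob_func`. The sketch below assumes OrdinaryDiffEq.jl's `EnsembleProblem`/`remake` interface, and `ScratchSys` and `rhs!` are hypothetical. It has not been benchmarked, and it is only one way to set this up.

```julia
using OrdinaryDiffEq

# Hypothetical mutable system holding a scratch buffer.
mutable struct ScratchSys
    buf::Vector{Float64}
end

function rhs!(du, u, p, t)
    p.sys.buf .= 2 .* u      # scratch work in this trajectory's own buffer
    du .= .-p.sys.buf
    return nothing
end

prob = ODEProblem(rhs!, [1.0, 2.0], (0.0, 1.0), (; sys = ScratchSys(zeros(2))))

# Each trajectory gets its own deep copy of sys, so threads never share a buffer.
prob_func(prob, i, repeat) = remake(prob; p = (; sys = deepcopy(prob.p.sys)))

eprob = EnsembleProblem(prob; prob_func)
sim = solve(eprob, Tsit5(), EnsembleThreads(); trajectories = 4)
```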

Here’s what I’ve been talking about with the saving callback: you’re using X as-is from whatever it got set to in the last call to the ODE function in some Runge-Kutta stage or deep within an iterative solver. This doesn’t make sense. Whatever state functions you want to save, you need to compute them in the callback and not rely on them already having been computed in the ODE function.

However, a more fundamental problem here is that you’re just saving a reference to the same mutable object every time. If you check save_values.saveval after your model run, you’ll notice that it just repeats the same value over and over, because they’re all references to the same object. In other words, you need a copy here. (A minor issue is that you incur an unnecessary allocation by using an intermediate vector inside the tuple comprehension.)
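The aliasing effect is easy to reproduce in plain Julia. In this sketch `buf`, `by_ref`, and `by_copy` are hypothetical. Pushing the buffer itself stores the same object three times, so every entry shows the final contents. Pushing `copy(buf)` stores a snapshot at each step.

```julia
buf = zeros(2)
by_ref = Vector{Vector{Float64}}()
by_copy = Vector{Vector{Float64}}()

for t in 1:3
    buf .= t                    # the "model" overwrites its buffer in place
    push!(by_ref, buf)          # every entry is the same object
    push!(by_copy, copy(buf))   # every entry is an independent snapshot
end

by_ref    # [[3.0, 3.0], [3.0, 3.0], [3.0, 3.0]]
by_copy   # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
```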

So, here’s my version. I changed a lot of stuff but I hope it’s not too hard to follow.

# More specific imports to reduce precompilation time
using DiffEqCallbacks, OrdinaryDiffEq, LinearAlgebra, PreallocationTools

struct Model{I,DC<:DiffCache}
    uindex::I
    Xcache::DC

    function Model(u, uindex)
        x = u[uindex]
        Xcache = DiffCache(x * x')
        return new{typeof(uindex),typeof(Xcache)}(uindex, Xcache)
    end
end

struct MySystem{M<:Model}
    models::Vector{M}
end

# Factor out the computation of X such that it can be called
# both in run_models! and the callback
function computeX!(Xcache, x)
    X = get_tmp(Xcache, x)
    # x and X are not static arrays anymore, so we need
    # in-place mul! to avoid allocations. In this case, we could
    # also have used broadcasting like this:
    # X .= -1e-2 .* x .* x'
    mul!(X, x, x', -1e-2, 0.0)  # arbitrary example function
    return X                    # minus sign for ODE stability
end

function run_models!(sys, u, du)
    # Use @views to avoid allocating copies
    @views for model in sys.models
        x = u[model.uindex]
        X = computeX!(model.Xcache, x)
        mul!(du[model.uindex], X, x)
    end
    return nothing
end

function odefunc!(du, u, p, t)
    run_models!(p.sys, u, du)
    return nothing
end

function simulate(sys, u; tspan=(0, 10), solver=Tsit5())
    p = (; sys=deepcopy(sys))
    # Note: calling computeX! in the callback to get the
    # correct X for the current u, then making a copy for
    # saving; without the copy, you'd be saving the same object
    # every time
    _save(u, models) = Tuple(copy(computeX!(m.Xcache, u[m.uindex])) for m in models)
    save_function(u, t, integrator) = _save(u, integrator.p.sys.models)
    # Let the compiler/runtime figure out the types
    save_values = SavedValues(eltype(u), typeof(_save(u, p.sys.models)))
    save_cb = SavingCallback(save_function, save_values)

    prob = ODEProblem(odefunc!, u, tspan, p)
    sol = solve(prob, solver; callback=save_cb)
    return sol, save_values
end

u = randn(6)
model1 = Model(u, 1:3)  # ranges rather than StaticVector to
model2 = Model(u, 4:6)  # slice `u`

sys = MySystem([model1, model2])
sol, save_values = simulate(sys, u; solver=Rosenbrock23())
# Rosenbrock23() works!
julia> save_values.saveval  # note: saved values not identical
                            # manually truncated output
9-element Vector{Tuple{Matrix{Float64}, Matrix{Float64}}}:
 ([-0.009253406785263131 ...], [-0.003323371868835428  ...])
 ([-0.009237726677218405 ...], [-0.003321548604950785  ...])
 ([-0.00908379835781979  ...], [-0.0033034253806941308 ...])
 ([-0.008436898508490906 ...], [-0.0032225306381912793 ...])
 ([-0.007876025534443998 ...], [-0.0031455031885456778 ...])
 ([-0.006875134923279897 ...], [-0.0029891221286686242 ...])
 ([-0.006097597772280344 ...], [-0.0028470677929763014 ...])
 ([-0.00509025323868019  ...], [-0.0026286419444377426 ...])
 ([-0.004469969928921746 ...], [-0.0024693940452286947 ...])

Here’s an out-of-place version using StaticArrays throughout and no preallocated buffers. The model is still allocation-free and likely much faster than the in-place version with preallocations (though I didn’t benchmark). Note that since all values are immutable, there’s no need to copy on save here.

using DiffEqCallbacks, OrdinaryDiffEq, LinearAlgebra, StaticArrays

struct Model{I}
    uindex::I
end

struct MySystem{N,M<:Model}
    # Type stable StaticArrays code requires a statically sized system,
    # hence using NTuple instead of Vector (StaticVector would also work)
    models::NTuple{N,M}
end

computeX(x) = -1e-2 * x * x'
run_model(x) = computeX(x) * x
run_models(sys, u) = vcat((run_model(u[model.uindex]) for model in sys.models)...)
odefunc(u, p, t) = run_models(p.sys, u)

function simulate(sys, u; tspan=(0, 10), solver=Tsit5())
    p = (; sys=sys)   # No need for deepcopy here, everything is immutable

    _save(u, models) = Tuple(computeX(u[m.uindex]) for m in models)
    save_function(u, t, integrator) = _save(u, integrator.p.sys.models)
    save_values = SavedValues(eltype(u), typeof(_save(u, p.sys.models)))
    save_cb = SavingCallback(save_function, save_values)

    prob = ODEProblem(odefunc, u, tspan, p)
    sol = solve(prob, solver, callback=save_cb)
    return sol, save_values
end

u = @SArray randn(6)
model1 = Model(SA[1:3...])  # Need statically sized slices here
model2 = Model(SA[4:6...])

sys = MySystem((model1, model2))
sol, save_values = simulate(sys, u; solver=Rosenbrock23())
# Rosenbrock23() works!
julia> save_values.saveval
8-element Vector{Tuple{SMatrix{3, 3, Float64, 9}, SMatrix{3, 3, Float64, 9}}}:
 ([-0.007063487884081139  ...], [-0.00015181128592323527 ...])
 ([-0.00705432779722293   ...], [-0.0001518095124753917  ...])
 ([-0.006964016455622655  ...], [-0.00015179178027558386 ...])
 ([-0.006460034242636762  ...], [-0.00015168383299083235 ...])
 ([-0.006024087098977398  ...], [-0.00015157603913119408 ...])
 ([-0.0052527054950404104 ...], [-0.000151342135726838   ...])
 ([-0.004654138649966144  ...], [-0.00015110788903541719 ...])
 ([-0.004426261051150212  ...], [-0.00015100224372919784 ...])

@danielwe I really appreciate your responses here! I felt like I was 80% there (debatable) but didn’t quite understand the concepts well enough to put it all together. I think looking at your code, I had a lot of these ideas at one point or another in my model but looking at my code I can see I just mixed too many of them together without ever fully understanding the concept or finishing the idea. Seeing your code really brought a lot of these ideas home for me. Thanks again!

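I think the main thing I was missing is your point that whatever state functions I want to save must be computed in the callback, instead of relying on their having been computed already in the ODE function.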

Indeed, I was thinking that I’m already calculating this quantity in the ode_func, why would I calculate it again? I see the error there now, but that is what led me down this path of using mutable containers to store everything to be accessed later in the SavingCallback


I’ve been there myself but eventually came to terms with that little bit of redundant computation being the price of using a flexible library where you can plug in a wide range of algorithms into your solver. It’s also not an outrageous price to pay: If I’m not mistaken, Tsit5() calls the ode function 6 times per time step, so the overhead of the extra computation is never more than 1/6, which is of course not negligible but also not catastrophic.

(If you’re determined enough, you may of course be able to hack your way out of the redundancy after all, but first make sure that what you’re doing is correct, and only then look for such optimizations, at your own peril. I’m starting to think of ways to do it, but they’re not pretty and I wouldn’t recommend them.)

I should also mention that all these admonitions assume that you’re saving X because it’s a state function you’re interested in for its own sake. But maybe you just want to be able to inspect every single X produced during the solver run, no matter whether the corresponding (u, t) is an actual solution state or not (say, for debugging purposes)? In this case, you probably shouldn’t use a callback at all, but just push values directly from the ODE function to a storage array. Here’s a sketch of how it could look:

p = (; storage=[], ... #=rest of p=#)

function odefunc(u, p, t)
    X = computeX(u, p, t)
    push!(p.storage, (t, copy(u), X))  # save everything from this call; copy in case u is a reused buffer
    ....  # rest of odefunc
end

Note that you’ll have many more entries in p.storage than in sol.u, and you have to look at the values of (u, t) to find the correspondences.
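Here is a runnable variant of that sketch, assuming OrdinaryDiffEq.jl is installed. It uses a simple decay problem, and `X = u .^ 2` stands in for the hypothetical `computeX`. Every RHS evaluation is logged, including trial stages, and the log ends up much longer than the saved solution. With an autodiff-based implicit solver, `u` could contain dual numbers, and this concretely typed log would then need a looser element type.

```julia
using OrdinaryDiffEq

# Log entries: (t, u, X) for every RHS call, trial stages included.
p = (; storage = Tuple{Float64, Vector{Float64}, Vector{Float64}}[])

function logged_decay!(du, u, p, t)
    X = u .^ 2                          # stand-in for computeX (hypothetical)
    push!(p.storage, (t, copy(u), X))   # copy: the solver may reuse the u array
    du .= .-u
    return nothing
end

prob = ODEProblem(logged_decay!, [1.0], (0.0, 1.0), p)
sol = solve(prob, Tsit5())

(length(p.storage), length(sol.u))   # many more logged calls than saved points
```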