Efficient way to handle operator splitting (periodically updating solution mid-solve) in JuliaDiffEq

I am writing a simple model of atmospheric chemistry in Julia using DifferentialEquations.jl to solve the chemical kinetics ODEs.

While chemical kinetics can be modeled as a system of ODEs, some processes in modern atmospheric models cannot be expressed as differential equations, e.g. time-dependent emissions, sub-grid scale processes, etc. They can be thought of as simple if-else clauses that need to be executed to mutate the solution array (species concentrations) at given time steps. The ODE solve therefore needs to be run step by step in conjunction with these if-else clauses, i.e. “operator splitting”.

I am therefore trying to find a performance-efficient way to mutate the solution array in the middle of solving the ODE. As an example, here is the code for a simple stratospheric mechanism:

using DifferentialEquations

function strato_t!(du, u, p, t)
    O, O1D, O3, NO, NO2, M, O2 = u  # species
    param = p(t)
    SUN = param[1]
    du[1] = 2*(2.643E-10)*SUN*SUN*SUN*O2 - 1*(8.018E-17)*O*O2 + 1*(6.120E-04)*SUN*O3 - 1*(1.576E-15)*O*O3 + 1*(7.110E-11)*O1D*M - 1*(1.069E-11)*NO2*O + 1*(1.289E-02)*SUN*NO2 # O
    du[2] = 1*(1.070E-03)*SUN*SUN*O3 - 1*(7.110E-11)*O1D*M - 1*(1.200E-10)*O1D*O3 # O1D
    du[3] = 1*(8.018E-17)*O*O2 - 1*(6.120E-04)*SUN*O3 - 1*(1.576E-15)*O*O3 - 1*(1.070E-03)*SUN*SUN*O3 - 1*(1.200E-10)*O1D*O3 - 1*(6.062E-15)*NO*O3 # O3
    du[4] = (-1)*(6.062E-15)*O3*NO + 1*(1.069E-11)*NO2*O + 1*(1.289E-02)*SUN*NO2 # NO
    du[5] = 1*(6.062E-15)*O3*NO - 1*(1.069E-11)*O*NO2 - 1*(1.289E-02)*SUN*NO2 # NO2
    du[6] = 0 # M (fix)
    du[7] = 0 # O2 (fix)
    nothing
end
function model_onestep()
    u0 = [6.624e+8, 9.906e+1, 5.326e+11, 8.725e+8, 2.240e+8, 8.120e+16, 1.697e+16]
    p  = t -> [t/3600.0]
    oprob = ODEProblem(strato_t!, u0, (0.0, 3600.0), p)
    @time sol = solve(oprob, dense=false, save_on=false, calck=false, save_everystep=false)
end

When solved in a single call to solve spanning from t = 0.0 to t = 3600.0, it takes

0.454232 seconds (2.65 M allocations: 147.560 MiB, 4.91% gc time)

(I believe there are allocation problems here, but that is irrelevant for now.)

However, I need to mutate the solution array at given steps, e.g. every 600.0 s. My current approach is to solve and remake in a loop, which is the prevalent approach in current Fortran-based models:

function model_chunked()
    u0 = [6.624e+8, 9.906e+1, 5.326e+11, 8.725e+8, 2.240e+8, 8.120e+16, 1.697e+16]
    p  = t -> [t/3600.0]
    oprob = ODEProblem(strato_t!, u0, (0.0, 600.0), p)
    for t in 0.0:600.0:3600.0
        @time sol = solve(oprob, dense=false, save_on=false, calck=false, save_everystep=false)
        
        u = sol.u[end]
        # ... need to do something with the solution here
        # to mimic operator splitting
        # e.g. if(t == 1200.0) u[2] += 1.5e+1
        
        oprob = remake(oprob; u0=u, tspan=(t, t+600.0))
    end
end

This allocates a lot:

  0.293247 seconds (1.82 M allocations: 101.559 MiB, 6.51% gc time)
  0.284285 seconds (1.82 M allocations: 101.539 MiB, 5.27% gc time)
  0.289665 seconds (1.82 M allocations: 101.533 MiB, 6.29% gc time)
  0.290912 seconds (1.82 M allocations: 101.529 MiB, 6.90% gc time)
  0.288586 seconds (1.82 M allocations: 101.520 MiB, 5.07% gc time)
  0.294330 seconds (1.82 M allocations: 101.514 MiB, 6.53% gc time)
  0.288314 seconds (1.82 M allocations: 101.508 MiB, 6.46% gc time)
  2.101101 seconds (12.77 M allocations: 710.813 MiB, 5.94% gc time)

I assume the (1.82 M allocations, 101.5 MiB) reported for each call are setup allocations in the solve routine. There is also a lot of garbage collection going on, which is not good for performance.

What would be an efficient way to mutate the solution array at given “checkpoints” within the DiffEq solve? Is there a straightforward way to “reuse” the temporary allocations made by solve, since the problem formulation has not really changed?

Thank you!

I have to admit I haven’t read all of your post in detail, but I think callbacks may be just what you need.

https://docs.juliadiffeq.org/latest/features/callback_functions.html

EDIT: And more specifically this: https://docs.juliadiffeq.org/latest/features/callback_library.html#PresetTimeCallback-1
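
A minimal sketch of what that could look like, assuming a toy two-species decay model in place of strato_t! and a hypothetical emit! operator that adds a fixed amount to species 2 every 600 s. PresetTimeCallback comes from DiffEqCallbacks.jl, which DifferentialEquations.jl re-exports; Tsit5 is only an example solver choice:

```julia
using DifferentialEquations

# Toy right-hand side standing in for strato_t! (assumption): linear decay.
function decay!(du, u, p, t)
    du .= -p .* u
    nothing
end

# Times at which the split operator should act (every 600 s).
checkpoints = collect(600.0:600.0:3000.0)

# Hypothetical emission operator: mutate the state in place at a checkpoint.
function emit!(integrator)
    integrator.u[2] += 15.0
    nothing
end

cb = PresetTimeCallback(checkpoints, emit!)

prob = ODEProblem(decay!, [1.0, 2.0], (0.0, 3600.0), 1e-3)
sol = solve(prob, Tsit5(); callback = cb, save_everystep = false)

u_final = sol.u[end]  # state at t = 3600 s
```

The solver stops at each listed time, applies emit!, and continues within the same solve call, so the solver cache is built only once.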

Regarding the suspected allocation problems: maybe. If there are any, they come from your code, because the function you’re calling is allocation-free in the inner loop. The way to check this is to increase or decrease the timespan and see if the allocations change. That will tell you whether you have allocations in the inner loop, in which case they come from strato_t!, and you can then double check type stability there.
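
A complementary check is to measure the right-hand side on its own with @allocated. The sketch below uses a stripped-down, hypothetical rhs! (not the full strato_t!) and compares the parameter closure from the question, which builds a new Vector on every call, with a tuple-returning variant:

```julia
# Parameter closure as in the question: allocates a new Vector per call.
p_alloc = t -> [t / 3600.0]
# Alternative (assumption): return a tuple, which needs no heap allocation.
p_tuple = t -> (t / 3600.0,)

# Stripped-down stand-in for strato_t! that reads SUN the same way.
function rhs!(du, u, p, t)
    sun = p(t)[1]
    du[1] = -sun * u[1]
    nothing
end

# The type parameter forces specialization on the closure type, so the
# measurement reflects rhs! itself rather than dynamic dispatch.
function measure(p::P) where {P}
    du = zeros(1)
    u = ones(1)
    rhs!(du, u, p, 0.0)               # warm-up so compilation is not counted
    return @allocated rhs!(du, u, p, 100.0)
end

bytes_alloc = measure(p_alloc)
bytes_tuple = measure(p_tuple)
println("Vector closure: ", bytes_alloc, " bytes; tuple closure: ", bytes_tuple, " bytes")
```

The Vector-returning closure is expected to report a nonzero byte count per call, while the tuple-returning one typically reports none. Any per-call allocation in the right-hand side is multiplied by the number of function evaluations the solver makes.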

Why not use the integrator interface?

http://docs.juliadiffeq.org/latest/basics/integrator.html

Or use callbacks as suggested above? What you’re allocating and re-allocating is the solver cache. There are cache arrays that are required when using the solver in mutating form of course (since what else would you be mutating?), and so each solve call needs to form those once.

Also, there is no need to have a big solver cache in your case because you only have 7 ODEs, so it would be faster for you to use static arrays anyways. I highly suggest that you read the “Optimizing your DiffEq code” tutorial if you care about performance.

http://tutorials.juliadiffeq.org/html/introduction/03-optimizing_diffeq_code.html
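
A rough sketch of the static-array form, assuming the StaticArrays.jl package is installed and again using a toy decay model rather than the stratospheric mechanism. The right-hand side is written out-of-place: it returns a new SVector instead of mutating du:

```julia
using DifferentialEquations
using StaticArrays

# Out-of-place right-hand side (toy model, assumption): returns a new SVector.
decay(u, p, t) = -p .* u

u0 = SVector(1.0, 2.0)
prob = ODEProblem(decay, u0, (0.0, 3600.0), 1e-3)

# Tsit5 is an arbitrary example solver.
sol = solve(prob, Tsit5(); save_everystep = false)

u_final = sol.u[end]
```

Because the state is an immutable, stack-allocated value, the solver does not need the mutable cache arrays of the in-place form. This tends to pay off only for small systems.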

Using that with the integrator interface or callbacks is what will give you the best results and is what’s recommended.
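
As a sketch of the integrator route, assuming the same toy decay model and a hypothetical emission pulse after the second chunk (t = 1200 s): init builds the solver cache once, and step!(integrator, dt, true) advances exactly dt each time.

```julia
using DifferentialEquations

# Toy right-hand side standing in for strato_t! (assumption): linear decay.
function decay!(du, u, p, t)
    du .= -p .* u
    nothing
end

prob = ODEProblem(decay!, [1.0, 2.0], (0.0, 3600.0), 1e-3)

# The solver cache is allocated here, once. Tsit5 is an example choice.
integrator = init(prob, Tsit5(); save_everystep = false)

for k in 1:6
    # Advance by 600 s; the third argument forces a stop exactly at t + dt.
    step!(integrator, 600.0, true)

    # Hypothetical split operator: emission pulse after the second chunk.
    if k == 2
        integrator.u[2] += 15.0
        # Flag the discontinuous change so cached derivative
        # information is recomputed before the next step.
        u_modified!(integrator, true)
    end
end

t_final = integrator.t
u_final = copy(integrator.u)
```

Unlike the solve/remake loop, nothing is rebuilt between chunks; only the state vector is touched. Calling u_modified! after a direct mutation is how the change is signalled to the solver in this sketch; set_u! is an alternative for replacing the whole state.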


Thank you klaff and Chris for the suggestions. I am testing each of the methods, and it seems like the integrator interface is best suited for the problem at hand, as I can essentially use step! and treat the ODE solve as a time-evolution operator. Using StaticArrays with 7 species is feasible; however, this is just a toy example for the purposes of this question, and at >100 species it proves to be inefficient.

As a follow-up question, the DiffEq documentation states in the integrator interface that

As low-level alternative to the callbacks, one can use set_t!, set_u! and set_ut! to mutate integrator states. Note that certain integrators may not have efficient ways to modify u and t. In such case, set_*! are as inefficient as reinit!.

Is there documentation on which integrators have inefficiencies with the set... mutators?

Besides set_u! being essential for operator splitting, there is a potential use case for set_t!: for example, it may prove more efficient to decouple some variables from the ODE (and use a steady-state approximation). This “reduced” ODE with fewer variables would be a different ODEProblem, which may do a few steps here and there until the approximation no longer holds and the “full” set of ODEs kicks in. This means that a discontinuous sequence of set_ut! jumps is induced.

Would this incur significant problems for the integrator’s internal state? (We don’t really care about the interpolation; only data at the end of each solve! call is needed)