How to sum outputs from DifferentialEquations.jl

I am building a series of compartmental ODE infection models with DifferentialEquations.jl. I would like to be able to plot the sum of two or more compartments (e.g. I might have people with undiagnosed infection and diagnosed infection, and want to plot the total number with infection).

A simple example is that I’d like to sum the entire population each time I change the model to check that I haven’t made a mistake that causes the population size to change.
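A minimal sketch of that conservation check. It assumes the same out-of-place SIR right-hand side as in the question and uses the Tsit5 solver. Because dS + dI + dR = 0, the total S + I + R should stay at its initial value up to floating-point error. The helper `population_drift` is hypothetical and is not part of DifferentialEquations.jl.

```julia
using DifferentialEquations, StaticArrays

# Same SIR right-hand side as in the question (out-of-place, returns an SVector)
function sir(u, p, t)
    S, I, R = u
    β, γ = p
    N = S + I + R
    return @SVector [-β * S * I / N, β * S * I / N - γ * I, γ * I]
end

prob = ODEProblem(sir, @SVector([0.999, 0.001, 0.0]), (0.0, 100.0), [0.3, 0.1])
sol = solve(prob, Tsit5())

# Hypothetical helper: largest deviation of the total population from its
# initial value, taken over every saved time point
population_drift(sol) = maximum(abs(sum(u) - sum(sol.u[1])) for u in sol.u)

drift = population_drift(sol)
println("maximum population drift: ", drift)
drift < 1e-8 || @warn "total population is not conserved; check the model equations"
```

Explicit Runge-Kutta methods such as Tsit5 preserve linear invariants like this one. A drift much larger than round-off therefore points at the equations rather than the solver.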

I’ve put a simple example below; the resulting plot is very unattractive. I could tune saveat to make the plot smoother, but I’m hoping there’s something within DifferentialEquations that I’ve missed that will do this more elegantly. Thanks.

using DifferentialEquations, Plots, StaticArrays

function sir(u, p, t)
    S, I, R = u 
    β, γ = p
    N = S + I + R 

    dS = -β * S * I / N 
    dI = β * S * I / N - γ * I 
    dR = γ * I

    return @SVector [dS, dI, dR]
end

u₀ = @SVector [.999, .001, 0]
p = [.3, .1]
tspan = (0., 100.)

prob = ODEProblem(sir, u₀, tspan, p)
sol = solve(prob)

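function plot_total_size(solution)
    # each element of solution.u is the state (S, I, R) at one saved time
    y = sum.(solution.u)   # total population at each saved time
    x = solution.t
    plot(x, y)
end

plot_total_size(sol)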

Thank you. I think that’s summing each compartment over all time points, but sum(sol, dims=1) looks like what I want: a sum across compartments at each time point.
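To see which direction `dims` sums in, here is a small plain matrix laid out the way `Array(sol)` lays out a solution: one row per compartment, one column per saved time point. The numbers are made up for illustration and are not model output.

```julia
# Rows = compartments (S, I, R), columns = saved time points.
# Illustrative numbers only.
A = [1 2;
     3 4;
     5 6]

per_time = sum(A, dims = 1)         # 1×2 Matrix: total across compartments at each time
per_compartment = sum(A, dims = 2)  # 3×1 Matrix: each compartment summed over all times

totals = vec(per_time)              # plain Vector, convenient for plot(t, totals)
println(totals)                     # [9, 12]
```

Note that `dims = 1` keeps a 1×n matrix rather than a vector, which matters when the result is handed to a plotting function.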

It still only plots at the saved time points, but I think that if I make saveat small enough, the curve will look smooth.
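An alternative to shrinking `saveat`, sketched under the assumption that the solve keeps its default dense output. The solution object can be called as a function, `sol(t)`, to evaluate its interpolant at any time in `tspan`. The total can then be sampled on as fine a grid as the plot needs without re-solving. The grid length of 500 is an arbitrary choice.

```julia
using DifferentialEquations, StaticArrays, Plots

function sir(u, p, t)
    S, I, R = u
    β, γ = p
    N = S + I + R
    return @SVector [-β * S * I / N, β * S * I / N - γ * I, γ * I]
end

tspan = (0.0, 100.0)
prob = ODEProblem(sir, @SVector([0.999, 0.001, 0.0]), tspan, [0.3, 0.1])
sol = solve(prob, Tsit5())   # no saveat, so the dense interpolant is kept

# Evaluate the interpolant on a fine grid; sol(t) returns the state at time t
ts = range(tspan[1], tspan[2], length = 500)
total = [sum(sol(t)) for t in ts]   # S + I + R at each grid time

plot(ts, total, label = "S + I + R")
```

Passing `saveat` also works, but as far as I know it switches the dense interpolant off by default. The approach above keeps the adaptive solve and only refines the plotting grid.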

You can also pass a function to the plotter.


plot(sum(sol, dims = 1)) doesn’t work for me, but as long as I don’t want all the compartments summed, I get good results with plot(sol.t, (@. sol[2,:] + sol[3,:])), which is more like the plots I will want to produce anyway. Thank you.
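One likely reason plot(sum(sol, dims = 1)) fails: `dims = 1` keeps a 1×n matrix, and Plots treats each column of a matrix as a separate series. The result is n one-point series instead of one line. Depending on the RecursiveArrayTools version, `sum` on the solution object itself may also behave differently, so the sketch below converts to a plain `Array` first and flattens the result with `vec`.

```julia
using DifferentialEquations, StaticArrays, Plots

function sir(u, p, t)
    S, I, R = u
    β, γ = p
    N = S + I + R
    return @SVector [-β * S * I / N, β * S * I / N - γ * I, γ * I]
end

prob = ODEProblem(sir, @SVector([0.999, 0.001, 0.0]), (0.0, 100.0), [0.3, 0.1])
sol = solve(prob, Tsit5())

# Array(sol) is a 3×n Matrix: rows are S, I, R and columns are saved times
M = Array(sol)

# sum(M, dims = 1) is 1×n; vec turns it into an n-element Vector,
# which Plots draws as a single line
total = vec(sum(M, dims = 1))

# A subset of compartments, as in the post: I + R at each saved time
i_plus_r = sol[2, :] .+ sol[3, :]

# Two columns -> two series, one label per column
plot(sol.t, [total i_plus_r], label = ["S + I + R" "I + R"])
```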

Look at the documentation:

i.e. something like plot(sol, vars=((t,s1,s2) -> (t,s1+s2), 0,1,2)), I guess?
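Here is a fuller version of that suggestion, as a sketch. The plot recipe calls the function with the variables listed after it and plots the (x, y) tuple it returns. Index 0 means time, and 1, 2, 3 are S, I, R. Recent releases spell the keyword `idxs`; `vars`, as used in the post, is the older name and may be deprecated depending on your version. The name `sum_I_R` is illustrative.

```julia
using DifferentialEquations, StaticArrays, Plots

function sir(u, p, t)
    S, I, R = u
    β, γ = p
    N = S + I + R
    return @SVector [-β * S * I / N, β * S * I / N - γ * I, γ * I]
end

prob = ODEProblem(sir, @SVector([0.999, 0.001, 0.0]), (0.0, 100.0), [0.3, 0.1])
sol = solve(prob, Tsit5())

# Receives (t, I, R) because of the indices (0, 2, 3) below and
# returns the (x, y) pair to draw: time against I + R
sum_I_R(t, I, R) = (t, I + R)

plot(sol, idxs = (sum_I_R, 0, 2, 3), label = "I + R")
```

Because the recipe evaluates the solution's interpolant rather than only the saved points, this form should also give a smooth curve without touching `saveat`.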


Thank you both - that works perfectly.