Help speeding up `DiscreteProblem`

Hi Everyone,

I’m finding that using DiscreteProblem on my toy example is about 4x slower than a vanilla loop with pre-allocated output.


My model is as follows:

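```julia
# Assumed helper (not shown in the post): proportion of a compartment that
# leaves during a step of length δt when the per-capita rate is r.
rate_to_proportion(r, δt) = 1 - exp(-r * δt)

function sir_map!(du,u,p,t)
    (S,I,R) = u
    (β,c,γ,δt) = p
    N = S+I+R
    infection = rate_to_proportion(β*c*I/N,δt)*S
    recovery = rate_to_proportion(γ,δt)*I
    @inbounds begin
        du[1] = S-infection
        du[2] = I+infection-recovery
        du[3] = R+recovery
    end
    nothing
end;
```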

And my plain Julia loop for solving it is as follows:

```julia
function solve_map(f, u0, nsteps, p)
    # Pre-allocate array with correct type
    sol = similar(u0, length(u0), nsteps + 1)
    # Initialize the first column with the initial state
    sol[:, 1] = u0
    # Iterate over the time steps
    @inbounds for t in 2:nsteps+1
        u = @view sol[:, t-1] # Get the current state
        du = @view sol[:, t] # Prepare the next state
        # Note: t is the column (step) index, not a time value; sir_map! ignores it
        f(du, u, p, t)       # Call the function to update du
    end
    return sol
end;
```
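A minimal end-to-end run of the loop above, as a sketch. The `rate_to_proportion` definition (`1 - exp(-r*δt)`) and all parameter values are assumptions for illustration; the post doesn't show them. Because infections move people from S to I and recoveries move them from I to R, each step leaves S + I + R unchanged up to rounding, which makes a cheap sanity check.

```julia
# Assumed helper (not shown in the post): proportion of a compartment that
# leaves during a step of length δt when the per-capita rate is r.
rate_to_proportion(r, δt) = 1 - exp(-r * δt)

function sir_map!(du, u, p, t)
    (S, I, R) = u
    (β, c, γ, δt) = p
    N = S + I + R
    infection = rate_to_proportion(β * c * I / N, δt) * S
    recovery = rate_to_proportion(γ, δt) * I
    @inbounds begin
        du[1] = S - infection
        du[2] = I + infection - recovery
        du[3] = R + recovery
    end
    nothing
end

function solve_map(f, u0, nsteps, p)
    sol = similar(u0, length(u0), nsteps + 1)  # one column per saved state
    sol[:, 1] = u0
    @inbounds for t in 2:nsteps+1
        u = @view sol[:, t-1]   # previous state
        du = @view sol[:, t]    # next state, written in place
        f(du, u, p, t)
    end
    return sol
end

# Hypothetical inputs, chosen only for illustration
δt = 0.1
nsteps = 400
u0 = [990.0, 10.0, 0.0]       # S, I, R
p = [0.05, 10.0, 0.25, δt]    # β, c, γ, δt

sol = solve_map(sir_map!, u0, nsteps, p)
println(size(sol))            # prints (3, 401)
println(sum(sol[:, end]))     # close to 1000.0, since S + I + R is conserved
```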

I solve the above using DiscreteProblem as follows.

```julia
using OrdinaryDiffEq  # provides DiscreteProblem and FunctionMap

prob_map = DiscreteProblem(sir_map!,u0,tspan,p)
sol_map = solve(prob_map,FunctionMap());
```
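To make the `DiscreteProblem` call runnable and directly comparable with the loop, here is a sketch with hypothetical inputs (the same assumed `rate_to_proportion` and illustrative values, not taken from the post). The time axis of a `DiscreteProblem` counts iterations of the map, so with `dt = 1.0` (passed explicitly here so the step count is unambiguous) a span of `(0, nsteps)` performs `nsteps` updates and saves `nsteps + 1` states, matching the columns produced by `solve_map`.

```julia
using OrdinaryDiffEq  # exports DiscreteProblem and FunctionMap

rate_to_proportion(r, δt) = 1 - exp(-r * δt)  # assumed helper definition

function sir_map!(du, u, p, t)
    (S, I, R) = u
    (β, c, γ, δt) = p
    N = S + I + R
    infection = rate_to_proportion(β * c * I / N, δt) * S
    recovery = rate_to_proportion(γ, δt) * I
    @inbounds begin
        du[1] = S - infection
        du[2] = I + infection - recovery
        du[3] = R + recovery
    end
    nothing
end

function solve_map(f, u0, nsteps, p)
    sol = similar(u0, length(u0), nsteps + 1)
    sol[:, 1] = u0
    for t in 2:nsteps+1
        f(view(sol, :, t), view(sol, :, t - 1), p, t)  # write state t from state t-1
    end
    return sol
end

# Hypothetical inputs for illustration
δt = 0.1
nsteps = 400
u0 = [990.0, 10.0, 0.0]
p = [0.05, 10.0, 0.25, δt]
tspan = (0.0, Float64(nsteps))   # time measured in map iterations

prob_map = DiscreteProblem(sir_map!, u0, tspan, p)
sol_map = solve(prob_map, FunctionMap(); dt = 1.0)

# Both approaches should produce the same trajectory
println(Array(sol_map) ≈ solve_map(sir_map!, u0, nsteps, p))
```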

I suspect the time difference between the implementations comes from `DiscreteProblem` not pre-allocating the output for its solution. Any suggestions on how to speed it up?
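One way to test the pre-allocation hypothesis, sketched under the same hypothetical setup, is to time the solve with and without storing the full trajectory. `save_everystep = false` is one of the common solver options and keeps only the start and end states. If most of the gap to the hand-written loop disappears, output storage is the main cost; if not, the time is going into per-step integrator overhead.

```julia
using OrdinaryDiffEq, BenchmarkTools

rate_to_proportion(r, δt) = 1 - exp(-r * δt)  # assumed helper definition

function sir_map!(du, u, p, t)
    (S, I, R) = u
    (β, c, γ, δt) = p
    N = S + I + R
    infection = rate_to_proportion(β * c * I / N, δt) * S
    recovery = rate_to_proportion(γ, δt) * I
    @inbounds begin
        du[1] = S - infection
        du[2] = I + infection - recovery
        du[3] = R + recovery
    end
    nothing
end

# Hypothetical inputs for illustration
u0 = [990.0, 10.0, 0.0]
p = [0.05, 10.0, 0.25, 0.1]
prob_map = DiscreteProblem(sir_map!, u0, (0.0, 400.0), p)

# Minimum run time in seconds: full trajectory (default) vs. start and end only
t_full = @belapsed solve($prob_map, FunctionMap(); dt = 1.0)
t_last = @belapsed solve($prob_map, FunctionMap(); dt = 1.0, save_everystep = false)
println("full: $(t_full * 1e6) μs, final state only: $(t_last * 1e6) μs")

sol_full = solve(prob_map, FunctionMap(); dt = 1.0)
sol_last = solve(prob_map, FunctionMap(); dt = 1.0, save_everystep = false)
```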

Try `SimpleFunctionMap` from SimpleDiffEq.jl. The full ODE integrator, with all of its other error checks, can get in the way when there's no continuous element.
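A sketch of that suggestion under the same hypothetical setup: the problem definition stays the same and only the algorithm changes. This assumes `SimpleFunctionMap` advances in unit steps across `tspan` (so no `dt` keyword is passed), which should line up with `FunctionMap` at `dt = 1.0`.

```julia
using OrdinaryDiffEq  # DiscreteProblem and FunctionMap
using SimpleDiffEq    # SimpleFunctionMap

rate_to_proportion(r, δt) = 1 - exp(-r * δt)  # assumed helper definition

function sir_map!(du, u, p, t)
    (S, I, R) = u
    (β, c, γ, δt) = p
    N = S + I + R
    infection = rate_to_proportion(β * c * I / N, δt) * S
    recovery = rate_to_proportion(γ, δt) * I
    @inbounds begin
        du[1] = S - infection
        du[2] = I + infection - recovery
        du[3] = R + recovery
    end
    nothing
end

# Hypothetical inputs for illustration
u0 = [990.0, 10.0, 0.0]
p = [0.05, 10.0, 0.25, 0.1]
prob_map = DiscreteProblem(sir_map!, u0, (0.0, 400.0), p)

sol_full = solve(prob_map, FunctionMap(); dt = 1.0)  # full integrator
sol_simple = solve(prob_map, SimpleFunctionMap())     # lightweight fixed-step loop
println(Array(sol_simple) ≈ Array(sol_full))
```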

Thanks @ChrisRackauckas! Using SimpleFunctionMap gets me nearly all the way there:

```
julia> @benchmark solve(prob_map,FunctionMap())
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  38.250 μs …  14.435 ms  ┊ GC (min … max): 0.00% … 99.47%
 Time  (median):     42.917 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   46.853 μs ± 192.679 μs  ┊ GC (mean ± σ):  6.77% ±  1.72%

    ▅▇▇▄▇▄▄▃▆█▆▆▆▆▃▅▂▂▁
  ▂▅█████████████████████▆▇▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▄
  38.2 μs         Histogram: frequency by time         59.4 μs <

julia> @benchmark solve(prob_map,SimpleFunctionMap())
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  14.916 μs …  15.643 ms  ┊ GC (min … max): 0.00% … 99.67%
 Time  (median):     16.709 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   18.751 μs ± 156.266 μs  ┊ GC (mean ± σ):  8.31% ±  1.00%

       ▅█▇▂▂▁  ▁
  ▁▂▃█████████▆███▇▅▆▅▅▄▅▄▄▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  14.9 μs         Histogram: frequency by time         24.3 μs <

julia> @benchmark solve_map(sir_map!, u0, nsteps, p)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  11.625 μs … 47.708 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     12.500 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.690 μs ±  1.352 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▄█   █  ▃▇▇
  ██▇▄██▆▅███▇▆▆▆▅▆▄▄▄▄▆▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  11.6 μs         Histogram: frequency by time          17 μs <
```
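The benchmarks above reference non-constant globals (`prob_map`, `u0`, `nsteps`, `p`). The BenchmarkTools manual recommends interpolating such variables with `$` so the measurement excludes dynamic dispatch on untyped globals. Here is a sketch of the same three comparisons written that way, with hypothetical inputs; absolute numbers may shift relative to the figures above. `@btime` returns the value of the benchmarked expression, which lets the results be checked against each other afterwards.

```julia
using OrdinaryDiffEq, SimpleDiffEq, BenchmarkTools

rate_to_proportion(r, δt) = 1 - exp(-r * δt)  # assumed helper definition

function sir_map!(du, u, p, t)
    (S, I, R) = u
    (β, c, γ, δt) = p
    N = S + I + R
    infection = rate_to_proportion(β * c * I / N, δt) * S
    recovery = rate_to_proportion(γ, δt) * I
    @inbounds begin
        du[1] = S - infection
        du[2] = I + infection - recovery
        du[3] = R + recovery
    end
    nothing
end

function solve_map(f, u0, nsteps, p)
    sol = similar(u0, length(u0), nsteps + 1)
    sol[:, 1] = u0
    @inbounds for t in 2:nsteps+1
        f(view(sol, :, t), view(sol, :, t - 1), p, t)
    end
    return sol
end

# Hypothetical inputs for illustration
nsteps = 400
u0 = [990.0, 10.0, 0.0]
p = [0.05, 10.0, 0.25, 0.1]
prob_map = DiscreteProblem(sir_map!, u0, (0.0, Float64(nsteps)), p)

# `$` splices each global into the benchmark as a local value
sol_fm = @btime solve($prob_map, FunctionMap(); dt = 1.0)
sol_sfm = @btime solve($prob_map, SimpleFunctionMap())
sol_loop = @btime solve_map(sir_map!, $u0, $nsteps, $p)
```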

Open an issue for the remaining gap. We should optimize that, @Oscar_Smith.


In SimpleDiffEq.jl?

yes please