Slow performance when solving SDE for a large number of functions in a loop

My use case: I have a large dataset of 1 million pairs of drift and diffusion functions (stored as strings of symbolic expressions), and I want to solve the corresponding SDEs.

This is my current approach:

using DifferentialEquations

function solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt; method=EM())
    sdes = []
    sols = []
    for i in 1:length(drift_functions)
        f = string_to_function(drift_functions[i])
        g = string_to_function(diffusion_functions[i])
        prob = SDEProblem(f, g, init_conditions[i], tspan)
        ensembleprob = EnsembleProblem(prob)
        @time sol = Base.invokelatest(solve, ensembleprob, method, EnsembleThreads(); dt=dt, saveat=dt, trajectories=num_paths)
        push!(sdes, prob)
        push!(sols, sol)
    end
    return sdes, sols
end

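function string_to_function(f_strs)
    # Parse each expression
    exprs = Meta.parse.(f_strs)

    # Build the body: du[i] = <i-th expression>
    body = Expr(:block, [:(du[$i] = $(exprs[i])) for i in 1:length(exprs)]...)

    # Generate an anonymous function. A named `f!` would be redefined in the
    # enclosing module on every call, so all returned functions would share
    # the most recently generated body.
    return eval(:((du, u, p, t) -> $body))
end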

function main()
    drift_functions = [["0.02*t - 0.1*u[1]"], ["u[1] - u[2] - u[1]*u[2]^2 - u[1]^3", "u[1] + u[2] - u[1]^2*u[2] - u[2]^3"]]
    diffusion_functions = [["0.4"], ["sqrt(1+u[2]^2)", "sqrt(1+u[1]^2)"]]
    init_conditions = [[1.0], [1.0, 1.0]]
    tspan = (0.0, 1.0)
    num_paths = 100
    dt = 0.01
    sdes, sols = solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt)
end

main()

The problem I am facing is that solving is very slow, because the solver has to be recompiled in every iteration of the loop:

2.126470 seconds (3.10 M allocations: 203.331 MiB, 4.39% gc time, 99.86% compilation time)
1.103900 seconds (782.84 k allocations: 48.191 MiB, 6.88% gc time, 99.73% compilation time: 18% of which was recompilation)

Is there a way to speed this up?

My first thought was that this is caused by the dynamically compiled functions and the resulting use of Base.invokelatest. However, the code with pre-compiled functions runs just as slowly:

using DifferentialEquations

function solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt; method=EM())
    sdes = []
    sols = []
    for i in 1:length(drift_functions)
        f = drift_functions[i]
        g = diffusion_functions[i]
        prob = SDEProblem(f, g, init_conditions[i], tspan)
        ensembleprob = EnsembleProblem(prob)
        @time sol = solve(ensembleprob, method, EnsembleThreads(); dt=dt, saveat=dt, trajectories=num_paths)
        push!(sdes, prob)
        push!(sols, sol)
    end
    return sdes, sols
end

function f1(du, u, p, t)
    du[1] = 0.02*t - 0.1*u[1]
end

function f2(du, u, p, t)
    du[1] = u[1] - u[2] - u[1]*u[2]^2 - u[1]^3
    du[2] = u[1] + u[2] - u[1]^2*u[2] - u[2]^3
end

function g1(du, u, p, t)
    du[1] = 0.4
end

function g2(du, u, p, t)
    du[1] = sqrt(1+u[2]^2)
    du[2] = sqrt(1+u[1]^2)
end

function main()
    drift_functions = [f1, f2]
    diffusion_functions = [g1, g2]
    init_conditions = [[1.0], [1.0, 1.0]]
    tspan = (0.0, 1.0)
    num_paths = 100
    dt = 0.01
    sdes, sols = solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt)
end

main()

2.214550 seconds (3.11 M allocations: 203.411 MiB, 8.17% gc time, 99.84% compilation time)
2.070660 seconds (1.52 M allocations: 96.149 MiB, 33.97% gc time, 99.85% compilation time)

Thanks in advance!

Reducing dynamism always helps. Can you write this as a single parameterized function instead?

There are many ways to deal with this, each with advantages and disadvantages.

One of them is the following: Wrap the function in a nonspecialized container:

struct UnspecFunction
    f::Function
end
(f::UnspecFunction)(args...) = f.f(args...)
...
f = UnspecFunction(drift_functions[i])

That way you won’t specialize all of the DiffEq machinery on the specific function. I used this trick in, e.g., “Avoid specializing all of ForwardDiff on every equation” (Pull Request #37 by KristofferC, bankofcanada/ModelBaseEcon.jl on GitHub).
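A minimal sketch of why the wrapper helps, using only Base Julia. Here `f1` and `f2` are simplified stand-ins for the drift functions from the question, not the exact ones. Every Julia function has its own type, so code that specializes on the function type is compiled again for each new function, whereas all wrapped functions share the single type `UnspecFunction`:

```julia
# Wrapper whose type does not depend on the wrapped function.
struct UnspecFunction
    f::Function   # abstract field type on purpose: hides the concrete function type
end

# Forward calls to the wrapped function (resolved by dynamic dispatch at run time).
(w::UnspecFunction)(args...) = w.f(args...)

# Simplified stand-ins for two in-place drift functions (assumed for illustration).
f1(du, u, p, t) = (du[1] = 0.02 * t - 0.1 * u[1]; nothing)
f2(du, u, p, t) = (du[1] = u[1] - u[2]; du[2] = u[1] + u[2]; nothing)

# Unwrapped: each function has its own type.
@show typeof(f1) == typeof(f2)   # false

# Wrapped: both have the same type.
w1, w2 = UnspecFunction(f1), UnspecFunction(f2)
@show typeof(w1) == typeof(w2)   # true

# The wrapper is still callable like the original function.
du = zeros(2)
w2(du, [1.0, 2.0], nothing, 0.0)
@show du                         # [-1.0, 3.0]
```

The trade-off is a dynamic dispatch on every call of the wrapped function, which is usually cheap compared with recompiling the solver for each equation.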

Before:

julia> main();
  0.450631 seconds (775.57 k allocations: 48.133 MiB, 2.08% gc time, 99.65% compilation time: 20% of which was recompilation)
  0.442212 seconds (781.97 k allocations: 48.473 MiB, 1.82% gc time, 99.67% compilation time: 17% of which was recompilation)

julia> main();
  0.455274 seconds (775.57 k allocations: 48.133 MiB, 3.92% gc time, 98.50% compilation time: 20% of which was recompilation)
  0.444641 seconds (781.97 k allocations: 48.473 MiB, 0.92% gc time, 99.66% compilation time: 15% of which was recompilation)

After:

julia> main();
  0.479668 seconds (838.09 k allocations: 49.240 MiB, 99.31% compilation time: 20% of which was recompilation)
  0.005831 seconds (77.95 k allocations: 2.536 MiB, 48.87% compilation time)

julia> main();
  0.012806 seconds (76.53 k allocations: 2.449 MiB, 52.36% compilation time)
  0.011077 seconds (77.95 k allocations: 2.536 MiB, 53.36% compilation time)

There are alternatives such as:

  • If there are common structures among the equations you can parameterize them in a different way (similar to how a polynomial can be parameterized by its coefficients). That way you only compile one function that gets different input data for each concrete function.
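As a rough sketch of that idea, assume (purely for illustration) that every drift component is a polynomial in the corresponding state variable. The names `poly_drift!` and `coeffs` are hypothetical. A single function then covers the whole family, and each concrete equation is just a different coefficient array passed as the parameter `p`:

```julia
# One drift function for the whole family:
# du[i] = c[1] + c[2]*u[i] + c[3]*u[i]^2 + ...  with c = p.coeffs[i]
function poly_drift!(du, u, p, t)
    for i in eachindex(u)
        du[i] = evalpoly(u[i], p.coeffs[i])
    end
    return nothing
end

# Two different "equations" are now just different data of the same type.
p_a = (coeffs = [[0.0, -0.1]],)        # du[1] = -0.1*u[1]
p_b = (coeffs = [[1.0, 0.0, -1.0]],)   # du[1] = 1 - u[1]^2

@show typeof(p_a) == typeof(p_b)       # true

du = zeros(1)
poly_drift!(du, [2.0], p_b, 0.0)
@show du                               # [-3.0]
```

Because `p_a` and `p_b` have the same type, anything compiled for `poly_drift!` with one of them can be reused with the other.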

I remember that the two of you already helped me six years ago when I started with Julia, so it is great to see that you are still around helping people! :hearts:

I used Chris’ approach of parameterizing the function and passing it the index:

using DifferentialEquations

function solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt; method=EM())
    sdes = []
    sols = []
    for i in 1:length(drift_functions)
        p = (i, drift_functions, diffusion_functions)
        prob = SDEProblem(f, g, init_conditions[i], tspan, p)
        ensembleprob = EnsembleProblem(prob)
        @time sol = Base.invokelatest(solve, ensembleprob, method, EnsembleThreads(); dt=dt, saveat=dt, trajectories=num_paths)
        push!(sdes, prob)
        push!(sols, sol)
    end
    return sdes, sols
end

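function string_to_function(f_strs)
    # Parse each expression
    exprs = Meta.parse.(f_strs)

    # Build the body: du[i] = <i-th expression>
    body = Expr(:block, [:(du[$i] = $(exprs[i])) for i in 1:length(exprs)]...)

    # Generate an anonymous function. A named `f!` would be redefined in the
    # enclosing module on every call, so all returned functions would share
    # the most recently generated body.
    return eval(:((du, u, p, t) -> $body))
end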

function f(du, u, p, t)
    idx = p[1]
    functions = p[2]
    return functions[idx](du, u, p, t)
end

function g(du, u, p, t)
    idx = p[1]
    functions = p[3]
    return functions[idx](du, u, p, t)
end

function main()
    drift_functions_str = [["0.02*t - 0.1*u[1]"], ["u[1] - u[2] - u[1]*u[2]^2 - u[1]^3", "u[1] + u[2] - u[1]^2*u[2] - u[2]^3"]]
    diffusion_functions_str = [["0.4"], ["sqrt(1+u[2]^2)", "sqrt(1+u[1]^2)"]]
    drift_functions = [string_to_function(drift_function_str) for drift_function_str in drift_functions_str]
    diffusion_functions = [string_to_function(diffusion_function_str) for diffusion_function_str in diffusion_functions_str]
    init_conditions = [[1.0], [1.0, 1.0]]
    tspan = (0.0, 1.0)
    num_paths = 100
    dt = 0.01
    sdes, sols = solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt)
end

main()

This takes

julia> main();
 2.316175 seconds (3.22 M allocations: 210.819 MiB, 5.46% gc time, 99.86% compilation time)
 0.002832 seconds (16.03 k allocations: 1.499 MiB, 0.01% compilation time)

julia> main();
 2.241963 seconds (3.22 M allocations: 210.819 MiB, 4.17% gc time, 99.84% compilation time)
 0.003162 seconds (16.03 k allocations: 1.499 MiB, 0.01% compilation time)

on my machine, so it is even faster than Kristoffer’s approach, which takes:

julia> main();
  2.197187 seconds (3.19 M allocations: 206.386 MiB, 3.91% gc time, 99.79% compilation time)
  0.010793 seconds (77.95 k allocations: 2.536 MiB, 57.63% compilation time)

julia> main();
  2.175405 seconds (3.19 M allocations: 206.386 MiB, 3.72% gc time, 99.79% compilation time)
  0.010782 seconds (77.95 k allocations: 2.536 MiB, 57.43% compilation time)

The speedup is probably because nothing needs to be recompiled after the first solve: the second call reports only 0.01% compilation time, compared with about 57% for Kristoffer’s approach.
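A small sketch of what is likely going on, using only Base Julia. The two anonymous functions are hypothetical stand-ins for the generated ones. The wrapper `f` and the parameter tuple have the same types for every index, so code specialized on them can be reused, and only the lookup `functions[idx]` is resolved at run time:

```julia
# Hypothetical stand-ins for the generated drift functions.
drift_functions = Function[
    (du, u, p, t) -> (du[1] = -0.1 * u[1]; nothing),
    (du, u, p, t) -> (du[1] = u[1] - u[2]; du[2] = u[1] + u[2]; nothing),
]

# Generic wrapper: picks the actual function by index at run time.
function f(du, u, p, t)
    idx, functions = p[1], p[2]
    return functions[idx](du, u, p, t)
end

p1 = (1, drift_functions)
p2 = (2, drift_functions)

# Same wrapper function and same parameter type for every equation.
@show typeof(p1) == typeof(p2)   # true

du = zeros(2)
f(du, [1.0, 2.0], p2, 0.0)
@show du                         # [-1.0, 3.0]
```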

Thanks for your help guys! :slight_smile: