I have the following use case:
I have a large dataset of 1 million pairs of drift and diffusion functions (stored as strings of symbolic expressions), and I want to solve the corresponding SDE for each pair.
This is my current approach:
using DifferentialEquations
"""
    solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt;
               method=EM())

Solve one ensemble SDE problem per entry of `drift_functions`.

For each index `i`:
1. The drift and diffusion expression strings are turned into in-place functions with
   `string_to_function`.
2. They are wrapped in an `SDEProblem` with initial condition `init_conditions[i]` over
   `tspan`.
3. The problem is solved as an `EnsembleProblem` with `num_paths` trajectories using
   `EnsembleThreads()`, a step of `dt` and `saveat=dt`.

The wall time of each solve is printed with `@time`.

# Returns
A tuple `(sdes, sols)` of vectors holding the problems and their ensemble solutions,
in input order.

# Throws
`DimensionMismatch` if `drift_functions`, `diffusion_functions` and `init_conditions`
do not all have the same length.
"""
function solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan,
                    num_paths, dt; method=EM())
    n = length(drift_functions)
    if length(diffusion_functions) != n || length(init_conditions) != n
        throw(DimensionMismatch(
            "drift_functions, diffusion_functions and init_conditions must have " *
            "the same length, got $n, $(length(diffusion_functions)) and " *
            "$(length(init_conditions))"))
    end
    # Every problem carries its own function types, so these containers are
    # necessarily heterogeneous.
    sdes = Any[]
    sols = Any[]
    sizehint!(sdes, n)
    sizehint!(sols, n)
    for (drift, diffusion, u0) in zip(drift_functions, diffusion_functions, init_conditions)
        f = string_to_function(drift)
        g = string_to_function(diffusion)
        prob = SDEProblem(f, g, u0, tspan)
        ensembleprob = EnsembleProblem(prob)
        # NOTE(review): each generated `f`/`g` has a distinct type, so the solver is
        # specialized (compiled) anew for every problem, which likely explains the
        # ~99% compilation time. A non-specializing problem, e.g.
        # `SDEProblem{true, SciMLBase.NoSpecialize}(f, g, u0, tspan)`, may avoid
        # this; confirm support in the installed SciMLBase version.
        # `invokelatest` is required because `f`/`g` were created by `eval` after
        # this function started running (world-age restriction).
        @time sol = Base.invokelatest(solve, ensembleprob, method, EnsembleThreads();
                                      dt=dt, saveat=dt, trajectories=num_paths)
        push!(sdes, prob)
        push!(sols, sol)
    end
    return sdes, sols
end
"""
    string_to_function(f_strs)

Build an in-place function `(du, u, p, t) -> nothing` from a collection of expression
strings. The `i`-th string becomes the right-hand side of `du[i] = ...`.

The expressions may refer to `du`, `u`, `p` and `t`. Each call returns a new, distinct
anonymous function, so separately generated drift and diffusion functions never alias
each other.

The function is created with `eval`. Any code that was already running when it was
created must therefore call it through `Base.invokelatest`.

!!! warning
    The strings are parsed and evaluated as arbitrary Julia code in this module.
    Only use this on trusted input.
"""
function string_to_function(f_strs)
    exprs = Meta.parse.(f_strs)
    assignments = [:(du[$i] = $ex) for (i, ex) in enumerate(exprs)]
    # An anonymous function gets a fresh type on every `eval`. A named
    # `function f!(...)` here would instead redefine the *same* global method on
    # every call, turning all previously returned functions into the latest one.
    return eval(:((du, u, p, t) -> begin
        $(assignments...)
        return nothing
    end))
end
"""
    main()

Solve a small example dataset of two SDEs, given as expression strings: a 1-D system
and a 2-D system. Each is solved with `solve_sdes`.

Returns the `(sdes, sols)` tuple produced by `solve_sdes`.
"""
function main()
    drift_functions = [["0.02*t - 0.1*u[1]"],
                       ["u[1] - u[2] - u[1]*u[2]^2 - u[1]^3",
                        "u[1] + u[2] - u[1]^2*u[2] - u[2]^3"]]
    diffusion_functions = [["0.4"], ["sqrt(1+u[2]^2)", "sqrt(1+u[1]^2)"]]
    # One initial value per equation. The first system has a single equation, so a
    # 2-element state would leave du[2] never written by its drift/diffusion.
    init_conditions = [[1.0], [1.0, 1.0]]
    tspan = (0.0, 1.0)
    num_paths = 100
    dt = 0.01
    sdes, sols = solve_sdes(drift_functions, diffusion_functions, init_conditions,
                            tspan, num_paths, dt)
    return sdes, sols
end

main()
The problem I am now facing is that solving is very slow, because the solver is recompiled for every SDE (i.e. on every iteration of the loop):
2.126470 seconds (3.10 M allocations: 203.331 MiB, 4.39% gc time, 99.86% compilation time)
1.103900 seconds (782.84 k allocations: 48.191 MiB, 6.88% gc time, 99.73% compilation time: 18% of which was recompilation)
Is there a way to speed this up?
My first thought was that this was caused by the dynamically generated functions and the resulting need for `Base.invokelatest`. However, the version with statically defined functions runs just as slowly:
using DifferentialEquations
"""
    solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan, num_paths, dt;
               method=EM())

Solve one ensemble SDE problem per pair of in-place drift/diffusion callables
`f(du, u, p, t)` / `g(du, u, p, t)`.

For each index `i`, an `SDEProblem` is built from `drift_functions[i]`,
`diffusion_functions[i]` and `init_conditions[i]` over `tspan`. It is then solved as an
`EnsembleProblem` with `num_paths` trajectories using `EnsembleThreads()`, a step of
`dt` and `saveat=dt`. The wall time of each solve is printed with `@time`.

# Returns
A tuple `(sdes, sols)` of vectors holding the problems and their ensemble solutions,
in input order.

# Throws
`DimensionMismatch` if `drift_functions`, `diffusion_functions` and `init_conditions`
do not all have the same length.
"""
function solve_sdes(drift_functions, diffusion_functions, init_conditions, tspan,
                    num_paths, dt; method=EM())
    n = length(drift_functions)
    if length(diffusion_functions) != n || length(init_conditions) != n
        throw(DimensionMismatch(
            "drift_functions, diffusion_functions and init_conditions must have " *
            "the same length, got $n, $(length(diffusion_functions)) and " *
            "$(length(init_conditions))"))
    end
    # Every problem carries its own function types, so these containers are
    # necessarily heterogeneous.
    sdes = Any[]
    sols = Any[]
    sizehint!(sdes, n)
    sizehint!(sols, n)
    for (f, g, u0) in zip(drift_functions, diffusion_functions, init_conditions)
        prob = SDEProblem(f, g, u0, tspan)
        ensembleprob = EnsembleProblem(prob)
        # NOTE(review): every Julia function has its own type, so the solver is
        # specialized (compiled) again for each distinct `f`/`g`, even without
        # `eval`/`invokelatest`. A non-specializing problem, e.g.
        # `SDEProblem{true, SciMLBase.NoSpecialize}(f, g, u0, tspan)`, may avoid
        # this; confirm support in the installed SciMLBase version.
        @time sol = solve(ensembleprob, method, EnsembleThreads();
                          dt=dt, saveat=dt, trajectories=num_paths)
        push!(sdes, prob)
        push!(sols, sol)
    end
    return sdes, sols
end
"""
    f1(du, u, p, t)

In-place drift of the 1-D example system: stores `0.02t - 0.1u[1]` in `du[1]`
and returns that value.
"""
f1(du, u, p, t) = (du[1] = 0.02 * t - 0.1 * u[1])
"""
    f2(du, u, p, t)

In-place drift of the 2-D example system. Writes both components of `du` from
`u[1]` and `u[2]` and returns the value written to `du[2]`.
"""
function f2(du, u, p, t)
    du[1] = u[1] - u[2] - u[1] * u[2]^2 - u[1]^3
    return du[2] = u[1] + u[2] - u[1]^2 * u[2] - u[2]^3
end
"""
    g1(du, u, p, t)

In-place diffusion of the 1-D example system: a constant `0.4` stored in `du[1]`
(also returned).
"""
g1(du, u, p, t) = (du[1] = 0.4)
"""
    g2(du, u, p, t)

In-place diagonal diffusion of the 2-D example system:
`du[1] = sqrt(1 + u[2]^2)` and `du[2] = sqrt(1 + u[1]^2)`.
Returns the value written to `du[2]`.
"""
function g2(du, u, p, t)
    du[1] = sqrt(1 + u[2]^2)
    return du[2] = sqrt(1 + u[1]^2)
end
"""
    main()

Solve the example SDEs defined by `f1`/`g1` (a 1-D system) and `f2`/`g2` (a 2-D
system) with `solve_sdes`.

Returns the `(sdes, sols)` tuple produced by `solve_sdes`.
"""
function main()
    drift_functions = [f1, f2]
    diffusion_functions = [g1, g2]
    # One initial value per equation. `f1`/`g1` only write du[1], so a 2-element
    # state for the first system would leave du[2] never written.
    init_conditions = [[1.0], [1.0, 1.0]]
    tspan = (0.0, 1.0)
    num_paths = 100
    dt = 0.01
    sdes, sols = solve_sdes(drift_functions, diffusion_functions, init_conditions,
                            tspan, num_paths, dt)
    return sdes, sols
end

main()
2.214550 seconds (3.11 M allocations: 203.411 MiB, 8.17% gc time, 99.84% compilation time)
2.070660 seconds (1.52 M allocations: 96.149 MiB, 33.97% gc time, 99.85% compilation time)
Thanks in advance!