Speeding up repetitive calls to ODEProblem

Hey ODE people. Thanks for all the mind-bending, great work on top-of-the-line packages. I have stupid code that solves the same ODE with different initial conditions and time spans thousands of times. Unfortunately, for this use case all of those solves add up to just one likelihood evaluation, so it’d be nice for them to be fast. :sweat_smile:

So I profiled the code, and the profile is painted with type instabilities inside ODEProblem.


I’m wondering if I can feed ODEProblem something more digestible; maybe it needs type annotations? Here’s the function that I’m calling so many times:

function construct_transitions(model, sol, start_time, end_time, 
    start_state, end_state)
    # returns the transition matrix element from start_time to end_time
    # constructed using the time-dependent potential given by sol
    u0::Vector{Float64} = zeros(model.nstates)
    u0[start_state] += 1
    tspan  = (start_time, end_time)
    function mutgrad!(du,u,model,t)
        ψ::Vector{Float64} = sol(t)
        for j in 1:model.nstates
            du[j] = 0.0
            for i in 1:model.nstates
                if i != j
                    du[j] += u[i] * mutation_matrix(model, t, i, j) * exp(ψ[j] - ψ[i]) - 
                        u[j] * mutation_matrix(model, t, j, i) * exp(ψ[i] - ψ[j])
                    # a most intuitive representation of the probability
                end
            end
        end
    end
    prob = ODEProblem(mutgrad!, u0, tspan, model)
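    # save_everystep = false and save_start = false keep only the final state
    psol = solve(prob, Tsit5(); tstops = model.tstops, save_everystep = false, save_start = false)
    # psol.u[end] is the state at end_time; psol.u works across old and new solution-indexing rules
    return psol.u[end][end_state]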
end

Possibly relevant information: sol is itself an ODE solution, and mutation_matrix is already annotated with ::Float64 for its return value. (This is just integrating the Kolmogorov equation with a time-dependent propagator.) If people are interested I can try to put together an MWE, but I’m hoping that preexisting expertise will be enough. So grateful to the Julia community. :raised_hands:


You should try to avoid creating a new ODEProblem on every call.

Probably remake is what you are searching for. For example:


using DifferentialEquations, BenchmarkTools  # BenchmarkTools provides @btime

function multiple_odes()
    f(u,p,t) = 1.01*u
    u0 = 1/2
    tspan = (0.0,1.0)
    prob = ODEProblem(f,u0,tspan)
    for i in 1:100
        prob = remake(prob, u0 = rand())
        sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
    end
end

function multiple_odes0()
    f(u,p,t) = 1.01*u
    u0 = 1/2
    tspan = (0.0,1.0)
    for i in 1:100
        u0 = rand()
        prob = ODEProblem(f,u0,tspan)
        sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
    end
end

Gives:

julia> @btime multiple_odes()
  268.775 μs (4104 allocations: 621.91 KiB)

julia> @btime multiple_odes0()
  1.019 ms (14057 allocations: 1.04 MiB)

Thanks for the tip @lmiq! remake looks very promising. I’ll have to refactor a little to share the ODEProblem instance between solves and will report back.
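For that refactor, it may help to have the right-hand side receive everything it needs through the parameter argument instead of capturing `sol` and `model` from the enclosing function, so that one problem can be reused with a new `u0`, `tspan` and `p` on each call. Below is a minimal Base-only sketch; the `rate` and `ψ` fields of the NamedTuple, and the toy functions filling them, are hypothetical stand-ins for `mutation_matrix` and the interpolated `sol`. As a sanity check on the Kolmogorov form, the entries of `du` should sum to zero, since the equation conserves total probability.

```julia
# Right-hand side of the Kolmogorov forward equation as a top-level function.
# All per-call data arrives through `p`; nothing is captured from an outer scope.
# `p.rate(t, i, j)` and `p.ψ(t)` are hypothetical stand-ins for
# mutation_matrix(model, t, i, j) and sol(t) from the question.
function mutgrad!(du, u, p, t)
    n = p.nstates
    ψ = p.ψ(t)
    for j in 1:n
        acc = 0.0
        for i in 1:n
            if i != j
                # inflow into state j minus outflow from state j
                acc += u[i] * p.rate(t, i, j) * exp(ψ[j] - ψ[i]) -
                       u[j] * p.rate(t, j, i) * exp(ψ[i] - ψ[j])
            end
        end
        du[j] = acc
    end
    return nothing
end

# Toy parameters for a 3-state system (illustrative values only).
toy_rate(t, i, j) = 0.1 * (i + 2j) * (1 + 0.5 * sin(t))
toy_ψ(t) = [0.0, 0.3 * t, -0.2 * t]
p = (nstates = 3, rate = toy_rate, ψ = toy_ψ)

u = [0.2, 0.5, 0.3]
du = zeros(3)
mutgrad!(du, u, p, 1.0)
```

With the right-hand side in this form, the refactored code could build one ODEProblem up front and call `remake(prob; u0 = ..., tspan = ..., p = ...)` for each transition.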

Hi, sorry to hijack the thread, but I am looking at a similar problem: do you know if remake is thread-safe? What I mean is, would it be possible to do

Threads.@threads for i in 1:100
    new_prob = remake(prob, u0 = rand())
    sol = solve(new_prob, Tsit5(), reltol=1e-8, abstol=1e-8)
end

It’s not clear to me how remake affects the original problem… I am fairly new to Julia, so please excuse my ignorance! :slight_smile:

I think it is not (I may be wrong). In that case you should use a pool of problems, something like:

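using Base.Threads

# one independent copy of the problem per thread
my_probs = [ deepcopy(prob) for _ in 1:nthreads() ]
@threads for ithread in 1:nthreads()
    # each iteration only ever touches its own copy, my_probs[ithread]
    for i in 1:repeated_runs
        new_prob = remake(my_probs[ithread], ....)
        solve(new_prob)
    end
end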

@hbooth I edited the code here: the remake should be inside the inner loop. Maybe that caused the confusion.


I was thinking about it from the perspective of using multithreading for the repeated runs, i.e. using each available thread to evaluate one parameter choice. In this case, the code

repeated_runs = nthreads()
my_probs = [ deepcopy(prob) for _ in 1:repeated_runs]

@threads for i in 1:repeated_runs
    new_prob = remake(my_probs[i], u0 = rand())
end

would surely just be equivalent in performance to the original

@threads for i in 1:repeated_runs
    u0 = rand()
    new_prob = ODEProblem(f,u0,tspan)
end

Does that make sense?


No, not really. Building the problem (i.e. prob = ODEProblem(f,u0,tspan)) is expensive relative to remaking it. You should also use remake to provide a different initial point for the problem, and if prob is not thread-safe (I think it isn’t), the best thing to do is to create a separate copy for each thread.
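To make the "one independent problem per unit of work" idea concrete without DifferentialEquations, here is a Base-only toy. A hypothetical immutable `ToyProblem` and a hypothetical `toy_remake` stand in for ODEProblem and remake, and a fixed-step Euler loop (`toy_solve`, also hypothetical) stands in for solve. Each iteration builds its own problem value and writes only to its own slot of a preallocated results vector, so nothing shared is mutated inside the threaded loop.

```julia
using Base.Threads

# Hypothetical stand-in for an ODE problem: immutable, so "remaking" it
# always produces a new value and never touches the original.
struct ToyProblem
    u0::Float64
    tspan::Tuple{Float64,Float64}
    rate::Float64
end

# Hypothetical analogue of remake: copy with selected fields replaced.
toy_remake(p::ToyProblem; u0 = p.u0, tspan = p.tspan, rate = p.rate) =
    ToyProblem(u0, tspan, rate)

# Fixed-step explicit Euler for du/dt = rate * u, standing in for solve.
function toy_solve(p::ToyProblem; nsteps = 1_000)
    t0, t1 = p.tspan
    h = (t1 - t0) / nsteps
    u = p.u0
    for _ in 1:nsteps
        u += h * p.rate * u
    end
    return u
end

prob = ToyProblem(0.5, (0.0, 1.0), 1.01)
initial_values = collect(range(0.1, 1.0; length = 100))
results = zeros(length(initial_values))

@threads for i in eachindex(initial_values)
    new_prob = toy_remake(prob; u0 = initial_values[i])  # local to this iteration
    results[i] = toy_solve(new_prob)                      # each i writes only its own slot
end
```

This only illustrates the pattern with an immutable toy. Whether the real remake and solve share any mutable state between threads is exactly the open question here, and the per-thread pool of deep copies remains the conservative option.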

(PS: I’m not a heavy user of these tools, so someone else may well have better ideas about all this.)


Precisely.

Your original problem is also that inference is better if you write ODEProblem{true}(f,u0,tspan), which states up front that f is in-place, but I’m going to see if I can fix that this week with Tricks.jl.


Just to clarify: remake(prob;..) is not an in-place function, since prob is immutable, correct? In that case, shouldn’t its result be assigned to a variable in the example code given in this thread, i.e.

    new_prob = remake(prob;..)
    solve(new_prob,...)

rather than

    remake(prob;..)
    solve(prob,...)

yes.
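That matches how non-mutating updates of immutable values work in Julia generally. As an analogy only (this is Base Julia, not DifferentialEquations code), `merge` on a NamedTuple returns a new tuple and leaves the original untouched, so its result has to be bound to a name before it can be used:

```julia
# An immutable NamedTuple standing in for an immutable problem object.
prob = (u0 = 0.5, tspan = (0.0, 1.0))

# Calling a non-mutating update and discarding the result leaves prob unchanged.
merge(prob, (u0 = 0.9,))

# Binding the result is what makes the updated value usable.
new_prob = merge(prob, (u0 = 0.9,))
```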


Thanks for pointing that out. I edited my posts to avoid future errors.


Finally got around to refactoring with remake. Here are the results:

julia> @time evaluate_loglikelihood(lt, seasonal_model(ne_shared, 2.0*10^-3, 0.0))
  0.497888 seconds (1.41 M allocations: 90.216 MiB)
-79.5242729458968
julia> @time reevaluate_loglikelihood(lt, seasonal_model(ne_shared, 2.0*10^-3, 0.0))
  0.106464 seconds (416.61 k allocations: 36.014 MiB)
-79.5242729458968

The profile looks good. Marking this as solved. Thanks for the help, everyone!
