Parameter estimation of an ODE in Julia is slower than in R

I’ve been trying to estimate the parameters of an ODE (by maximum likelihood) using the DifferentialEquations and Optim packages in Julia, aiming for faster convergence than in R. The problem is that I can still do it faster in R. The only advantage so far is that Julia reaches a lower value of the negative log-likelihood, which is better, but I was really hoping to cut the computational time with Julia.

Here is a reproducible example. I generated the data in R first and then used the exact same data in Julia; it is quite similar to the real data I’m actually working with. The code is essentially the same in both languages, so I don’t know what else I can change to make Julia faster. For a fair comparison I’m using the Nelder-Mead method in both languages. Later I hope to switch to gradient methods with automatic differentiation, which is possible in Julia but not in R, but I don’t know how to do that yet either.
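(To give an idea of what I mean, the call I would eventually like to run is sketched below; I haven’t tested this, and it assumes the loglik1 objective defined further down is ForwardDiff-compatible:)

# Untested sketch: Optim can forward-differentiate the objective if the
# ODE solve accepts dual numbers, which DifferentialEquations supports.
result_ad = optimize(b -> loglik1(data_s.AGE, data_s.HDOM, data_s.ID, b),
                     inits, BFGS(); autodiff = :forward)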

Running time in R: 26.95 seconds
Running time in Julia: 34.76 seconds (103.59 M allocations: 22.318 GiB, 13.91% gc time, 0.04% compilation time)
Running time in Julia (with edit): 21.47 seconds (64.50 M allocations: 22.087 GiB, 20.61% gc time, 0.02% compilation time)

data (if you want to download it): Link to drive with data

Julia code:

#Loading the packages
using DifferentialEquations 
using Plots
using Optim 
using DataFrames
using Distributions
using BenchmarkTools
using LinearAlgebra
using StatsBase
using CSV

data_s = DataFrame(CSV.File("Data/data_sim.csv"));

const nplot = length(unique(data_s.ID))

function f(u,p,t)
    A1    = I(50) #This is not constant but let's assume it for now
    du    = p[1].*(A1*u).*(1/t.^(1+p[2]))
    return du
end

function DEsol_l(Age,parms,n)

    u0            = parms[1:n]
    time_int      = (minimum(Age),maximum(Age))    
    
    prob          = ODEProblem(f, u0, time_int, parms[(n+1):(n+2)])
    sol           = solve(prob, saveat = unique(Age))
    sol_v         = vec(transpose(sol[:,:]))   
    return sol_v
end

function loglik1(time,var,ID,θ)
    
    nplot         = length(unique(ID))
    sigma         = exp(θ[nplot+3])
    hpred         = DEsol_l(time,θ,nplot)
    residual      = hpred .- var
    result        = sum(logpdf.(Normal(0, sigma), residual))
    
    return -result

end

inits = combine(groupby(data_s, :ID), :HDOM => minimum)
inits = [inits[!,2]; 5; 0.5; 1 ];


@time result1 = optimize(b->loglik1(data_s.AGE,data_s.HDOM,data_s.ID,b), inits,NelderMead(),Optim.Options(iterations = 1000000)) 

R code:

#Data generation
library(data.table) # for setDT()
library(deSolve)    # for rk4()
library(optimx)     # for optimr()

set.seed(123)
nseries <- 50
AGE     <- rep(seq(5,30,3),nseries)
HDOM    <- 15*exp(-3.5/(0.9*AGE^(0.9))) + abs(rnorm(length(AGE)))*AGE*0.05
ID      <- rep(1:nseries,each=length(unique(AGE)))

data_s   <- data.frame(AGE,HDOM,ID)
setDT(data_s)
plot(AGE,HDOM)

write.csv(data_s,"data_sim.csv")

#DE function
fx2 <- function(t,h,p)
{
  p1  <- p['p1'] 
  p2  <- p['p2']
  hdot = p1*(A%*%h)*(1/t^(1+p2))
  
  return(list(c(hdot)))
}


#DE Solver 
solveDE2 <- function(times, parms,nplot)
{
  res <- rk4(y      = parms[1:nplot],
             times  = times,
             p      = parms,
             func   = fx2)
  return(list(time = rep(res[,1],nplot), Height = c(res[,-1])))
}

#Loglik function

log.lik2 <- function(params, data)
{
  model <- solveDE2(times  = unique(data$AGE),
                    parms  = params,
                    nplot  = length(unique(data$ID)))
  
  value <- sum(dnorm(model[[2]], data$HDOM, exp(params["sds"]), log = TRUE ))
  
  
  return(-value)
}

A     <- diag(nseries)
h0    <- data_s[,min(HDOM), by =  ID]
inits <- c(h0 = h0[[2]],p1=5,p2=0.5,sds=1)


system.time(
mod2      <- with(data_s,
                  optimr(par      = inits,
                         fn       = log.lik2,
                         data     = data_s,
                         method   = "Nelder-Mead",
                         control  = list(maxit=100000)))
)

mod2

Thanks!


The global variables here are likely a significant issue. I recommend either making these const or wrapping everything into a function.

See Performance Tips · The Julia Language
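For a toy illustration of the difference (the names here are hypothetical, not from the code above):

x = 1.0        # non-const global: its type can change at any time,
g() = x + 1.0  # so g() cannot specialize and must box the access

const y = 1.0  # const global: the type is fixed forever,
h() = y + 1.0  # so h() compiles down to a simple float add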


Thanks, Mark,

I made some edits to the code, and now the only global variables are the data and the initial values. Running time went from 34.76 to 21.47 seconds in Julia. It is now faster than R, but not by much.
Do you have any other specific recommendations for optimizing this code?

I’m working through a few specific recommendations. One is to use the in-place form. There are a few areas where I think you can pre-allocate memory instead of allocating new memory on each call.

The beginning looks like this at the moment:

using DifferentialEquations
using Optim
using DataFrames
using Distributions
using LinearAlgebra
using StatsBase
using CSV

function get_inits()
    inits = combine(groupby(data_s, :ID), :HDOM => minimum)
    inits = inits[!,2]
    inits = [inits; 5; 0.5; 1];
end

function f!(du, u, p, t)
    du .= p[1].*(A1*u).*(1/t^(1+p[2]))
end

const data_s = DataFrame(CSV.File("Data/data_sim.csv"));
const AGE = data_s.AGE
const time_int      = (minimum(AGE),maximum(AGE))
const unique_AGE = unique(AGE)
const residual = similar(data_s.HDOM)
const nplot = length(unique(data_s.ID))
const A1    = I(nplot);

const inits = get_inits()

I’m still testing this, but I think we can also use a single integrator. This eliminates many allocations:

const prob          = ODEProblem(f!, @view(inits[1:nplot]), time_int, @view(inits[(nplot+1):(nplot+2)]))
const integrator    = init(prob, RK4(), saveat = unique_AGE)

function DEsol_l(parms)
    u0              = @view(parms[1:nplot])
    integrator.p[1] = parms[nplot+1]
    integrator.p[2] = parms[nplot+2]
    reinit!(integrator, u0)
    sol             = solve!(integrator)
    sol_v           = vec(transpose(@view(sol[:,:])))
    return sol_v
end

Edited: Revised DEsol_l to use reinit! rather than the earlier set_u! attempt.

Next, the log-likelihood function:

const HDOM = data_s.HDOM
function loglik1(θ)

    sigma         = exp(θ[nplot+3])
    hpred         = DEsol_l(θ)
    residual     .= hpred .- HDOM
    N             = Normal(0, sigma)
    result        = sum(r -> logpdf(N,r), residual)

    return -result

end

For the logpdf we could have preallocated another array,

const logpdf_residual = similar(HDOM)

function loglik1(θ)

    sigma            = exp(θ[nplot+3])
    hpred            = DEsol_l(θ)
    residual        .= hpred .- HDOM
    N                = Normal(0, sigma)
    logpdf_residual .= logpdf.(N, residual)
    result           = sum(logpdf_residual)

    return -result
end

Since we’re just going to sum logpdf_residual, there’s no reason to allocate that memory at all. Instead we use the form of sum that takes a function and accumulates the values as it goes.
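In a toy setting the difference looks like this (hypothetical array):

xs = rand(10_000)
sum(abs2.(xs))  # materializes a temporary vector, then sums it
sum(abs2, xs)   # applies abs2 on the fly; no temporary allocated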

The final result looks like the following. Hopefully I did not make a mistake.

julia> @time result1 = optimize(loglik1, inits,NelderMead(),Optim.Options(iterations = 1000000))
  3.182967 seconds (7.51 M allocations: 2.534 GiB, 11.71% gc time, 3.93% compilation time)
julia> @time result1 = optimize(loglik1, inits,NelderMead(),Optim.Options(iterations = 1000000))
  2.022496 seconds (4.49 M allocations: 1.587 GiB, 12.58% gc time)

Update: I added the full example below:

Final full example
using DifferentialEquations
using Optim
using DataFrames
using Distributions
using LinearAlgebra
using StatsBase
using CSV

function get_inits()
    inits = combine(groupby(data_s, :ID), :HDOM => minimum)
    inits = inits[!,2]
    inits = [inits; 5; 0.5; 1];
end 
    
# isinplace form  
function f!(du, u, p, t)
    du .= p[1].*(A1*u).*(1/t^(1+p[2]))
end 
    
# Parse CSV and extract columns
const data_s          = DataFrame(CSV.File("Data/data_sim.csv"));
const AGE             = data_s.AGE
const HDOM            = data_s.HDOM

# Calculate AGE derived properties
const time_int        = (minimum(AGE),maximum(AGE))    
const unique_AGE      = unique(AGE)

# Preallocate constant data structures
const residual        = similar(HDOM)
const nplot           = length(unique(data_s.ID))
const A1              = I(nplot);

# Setup OrdinaryDiffEq.ODEIntegrator
const inits           = get_inits()
const prob            = ODEProblem(f!, @view(inits[1:nplot]), time_int, @view(inits[(nplot+1):(nplot+2)]))
const integrator      = init(prob, RK4(), saveat = unique_AGE)

function DEsol_l(parms)
    u0            = @view(parms[1:nplot])
    integrator.p[1] = parms[nplot+1]
    integrator.p[2] = parms[nplot+2]
    reinit!(integrator, u0) # corrected line
    sol           = solve!(integrator)
    sol_v         = vec(transpose(@view(sol[:,:])))   
    return sol_v
end

function loglik1(θ)
    sigma         = exp(θ[nplot+3])
    hpred         = DEsol_l(θ)
    residual      .= hpred .- HDOM
    N = Normal(0, sigma)
    result        = sum(r -> logpdf(N,r), residual)

    return -result
end

# first run, includes compilation
@time result1 = optimize(loglik1, inits,NelderMead(),Optim.Options(iterations = 1000000))
#  3.182967 seconds (7.51 M allocations: 2.534 GiB, 11.71% gc time, 3.93% compilation time)

# second run
@time result1 = optimize(loglik1, inits,NelderMead(),Optim.Options(iterations = 1000000))
#     2.022496 seconds (4.49 M allocations: 1.587 GiB, 12.58% gc time)

Edited: Revised to use reinit!


That is still allocating in the inner loop of the ODE solve because of the matmul.

I suggest looking at the following resources:

https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/


It seems better to use division than to multiply by (1/x). But what is the purpose of A1 here? If u is a vector, multiplying by the identity is completely redundant. Could this not just be

du .= u .* (p[1]/t^(1+p[2]))

?

function f!(du, u, p, t)
    du .= p[1].*(A1*u).*(1/t^(1+p[2]))
end 

function g!(du, u, p, t)
    du .= u .* (p[1]/t^(1+p[2]))
end

A1 = I(50)
u = rand(50)
p = (3.0, 4.0)
t = 0.45
du = similar(u)
julia> g!(du, u, p, t) == f!(du, u, p, t)
true

julia> @btime f!($du, u, $p, $t) setup=(u=rand(50));
  756.552 ns (7 allocations: 656 bytes)

julia> @btime g!($du, u, $p, $t) setup=(u=rand(50));
  14.930 ns (0 allocations: 0 bytes)

Thanks for taking the time for this, Mark.

I was looking at your code, trying to see which changes were the most effective. Defining the problem with the in-place form cuts the time from ~25 to ~10 seconds. Removing the redundant matrix (A) saves another ~3 seconds. With that I’m down to ~7 seconds, which is 4 times faster than R.

Defining all the variables as constants, wrapping the initial values in a function, and adding @view in some places don’t seem to save much time. The other change that makes your code faster is the integrator, but that was the most confusing part for me. Can you explain a bit what you did there? Unfortunately, I think there is a problem with that part: with it, the optimization does not converge to the “right” value. The likelihood value it reaches with this change is 2329, whereas we were getting 350 with the slower code.

Laura

The overall approach here is to eliminate memory allocations. Chris’s video explains why this is important. The in-place form helps by reusing already-allocated memory. Removing the identity matrix is also helpful because it eliminates the allocating matrix multiplication.
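As a toy illustration of the in-place difference (hypothetical right-hand sides, measured with BenchmarkTools):

using BenchmarkTools

fo(u, p, t) = p .* u                        # out-of-place: a new vector every call
fi!(du, u, p, t) = (du .= p .* u; nothing)  # in-place: writes into preallocated du

u, p, du = rand(50), rand(50), zeros(50)
@btime fo($u, $p, 0.1);        # allocates one 50-element vector per call
@btime fi!($du, $u, $p, 0.1);  # 0 allocations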

Another issue here is the globals. In Julia these can make code harder to compile and optimize, since a non-constant global can change type at any point in time; you want to avoid such globals if at all possible. Currently, the easiest way to make them type-stable is to declare them const. The const may be confusing because I appear to be mutating some of those variables; however, in those cases the “constant” bindings refer to structures with mutable contents. Pulling the variables out of the functions as global consts also helps with memory allocations by enabling the reuse of memory.

Other solutions to the global problem include using const Refs or doing everything within the local scope of a function. Since you can define functions within functions, creating closures, you could keep a structure very similar to your original code.
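For example, a hypothetical sketch of both options:

# Option 1: a const Ref — the binding is constant and concretely typed,
# but the contents can still be replaced.
const data_ref = Ref{Vector{Float64}}()
data_ref[] = rand(10)  # updates the contents, not the binding

# Option 2: a closure — build the objective inside a function so every
# captured variable lives in a local, type-stable scope.
function make_loglik(hdom)
    loglik(θ) = sum(abs2, θ .- hdom)  # toy objective capturing hdom
    return loglik
end
obj = make_loglik(rand(10))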

Reusing the same integrator is another attempt to remove memory allocations, as well as to avoid possibly unnecessary preprocessing, by recycling the integrator structure on each run. The idea is that you are solving a series of very similar ODE problems; rather than reposing the problem from scratch each time, we reuse the old construction and change only what we need. This exploits the CommonSolve interface:

solve(args...; kwargs...) = solve!(init(args...; kwargs...))

A call to solve is essentially an init followed by a solve!. The exclamation mark is a convention indicating a mutating function that may modify one of its arguments. In this case init returns an ODEIntegrator. From there, we use the integrator interface to mutate the integration problem rather than recreate it from scratch:

https://diffeq.sciml.ai/stable/basics/integrator/
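A stripped-down sketch of that pattern on a toy problem (not the code above):

using OrdinaryDiffEq

toy!(du, u, p, t) = (du .= p[1] .* u; nothing)
prob = ODEProblem(toy!, ones(3), (0.0, 1.0), [0.5])
integ = init(prob, Tsit5())      # pay the setup cost once

for p1 in (0.1, 0.2, 0.3)
    integ.p[1] = p1              # mutate the parameter in place
    reinit!(integ, ones(3))      # reset state and time, reuse the workspace
    sol = solve!(integ)          # integrate again without rebuilding anything
end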

Perhaps I really did need that set_t! call I commented out, or perhaps it is not valid to modify the parameters via the mechanism I used; changing p that way may not be fully supported. In retrospect, using reinit! might be better. I will need to debug this later. Perhaps @ChrisRackauckas can see what I did wrong sooner than I can?


Before doing all of that, run a profiler and share the flamegraph. Always look at profiles to guide what to optimize. If you post that here, I’ll point to the next performance improvement.


Also, share sol.alg_choice from one of the solutions. Is this problem stiff or non-stiff? The solver algorithm hasn’t been chosen explicitly, so it’s just using the default; there’s probably another big performance win waiting there.


Sure. I don’t know the best way to share the flamegraph. I get this:

[flamegraph screenshot]

And when I solve the equation using the parameters from the solution, I get a vector of ones in sol.alg_choice. I’m trying to figure out what that means, but I can’t find it documented.

All of this is for the following code, which incorporates several of Mark’s suggestions:

const nplot = length(unique(data_s.ID))
const AGE   = data_s.AGE
const HDOM  = data_s.HDOM

function f(du,u,p,t)
    du .= p[1].*(u).*(1/t.^(1+p[2]))
    nothing
end

function DEsol_l(Age,parms)
    
    time_int      = (minimum(Age),maximum(Age))
    u0            = parms[1:nplot]
    prob          = ODEProblem(f, u0, time_int, parms[(nplot+1):(nplot+2)])
    sol           = solve(prob, saveat = unique(Age),save_everystep=false)
    sol_v         = vec(transpose(sol[:,:]))   
    return sol_v
end

function loglik1(time,var,θ)

    sigma         = exp(θ[nplot+3])
    hpred         = DEsol_l(time,θ)
    residual      = hpred .- var
    result        = sum(logpdf.(Normal(0, sigma), residual))
    
    return -result

end

function get_inits()
    inits = combine(groupby(data_s, :ID), :HDOM => minimum)
    inits = inits[!,2]
    inits = [inits; 5; 0.5; 1];
end 

const inits1           = get_inits();

@profilehtml result1 = optimize(b->loglik1(AGE,HDOM,b), inits1,NelderMead(),Optim.Options(iterations = 1000000))


I noted that you used rk4 in the R code. I think Chris might have been interested in the following.

function DEsol_l(Age,parms,n)

    u0            = parms[1:n]
    time_int      = (minimum(Age),maximum(Age))    
    
    prob          = ODEProblem(f, u0, time_int, parms[(n+1):(n+2)])
    sol           = solve(prob, saveat = unique(Age))
    return sol
end

julia> sol = DEsol_l(data_s.AGE, inits, length(unique(data_s.ID)))

julia> sol.alg.algs[1]
Tsit5(stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false))

julia> sol.alg.algs[2]
Rosenbrock23{10, false, LUFactorization{RowMaximum}, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}(LUFactorization{RowMaximum}(RowMaximum()), OrdinaryDiffEq.DEFAULT_PRECS)

julia> sol.alg_choice
9-element Vector{Int64}:
 1
 1
 1
 1
 1
 1
 1
 1
 1

If t is a vector (is it?) you are breaking the broadcast fusion, since (1/t.^(1+p[2])) allocates a vector.

Make sure you dot all the right places. If t is a vector:

du .= u .* p[1] ./ t.^(1+p[2])
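A toy comparison of the fused and unfused versions:

u = rand(1000); t = rand(1000); du = similar(u); p1, p2 = 2.0, 0.5

du .= u .* p1 ./ t .^ (1 + p2)  # fully dotted: one fused loop, no temporaries

tmp = t .^ (1 + p2)             # unfused equivalent: materializes tmp first
du .= u .* p1 ./ tmp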

That means it’s a non-stiff ODE. Most of the time is spent setting up the stiff ODE solver, which is never used. Change solve(prob, saveat = unique(Age)) to solve(prob, Tsit5(), saveat = unique(Age)) to hardcode a 5th-order explicit RK method, like the R code does. solve(prob, Vern7(), saveat = unique(Age)) is also worth trying.

Note that one oddity here is that the R code has no error control on the solution, so :sweat_smile: it’s not an apples-to-apples comparison. To make it apples-to-apples, use solve(prob, RK4(), adaptive=false, tstops = times), though I’d highly recommend not doing that and sticking with error-controlled integration, of course.


I’ll be more careful with that.
I changed it to du .= u.*(p[1]/t^(1+p[2])), although it really didn’t have any effect.

I understand. Specifying the method saved time too; it takes 5 seconds now.
I think I can start using the code with real data now.

One final question now that you might be able to infer what I’m trying to do. Do you think that this can be done more efficiently if I use the DiffEqParamEstim package?

I finally just opened it up. These are the easy optimizations, just eliminating the big stuff from the profile:

using DifferentialEquations
using Plots
using Optim
using DataFrames
using Distributions
using BenchmarkTools
using LinearAlgebra
using StatsBase
using CSV

data_s = DataFrame(CSV.File("data_sim.csv"));

const nplot = length(unique(data_s.ID))
const AGE = data_s.AGE
const HDOM = data_s.HDOM
const UNIQUE_AGES = unique(AGE)

function f(du, u, p, t)
    du .= p[1] .* (u) .* (1 / t .^ (1 + p[2]))
    nothing
end

function DEsol_l(Age, parms)

    time_int = (minimum(Age), maximum(Age))
    u0 = parms[1:nplot]
    prob = ODEProblem{true}(f, u0, time_int, @view parms[(nplot+1):(nplot+2)])
    sol = solve(prob, Tsit5(), saveat=UNIQUE_AGES, save_everystep=false)
    sol_v = vec(transpose(Array(sol)))
    return sol_v
end

function loglik1(time, var, θ)

    sigma = exp(θ[nplot+3])
    hpred = DEsol_l(time, θ)
    residual = hpred .- var
    result = sum(logpdf.(Normal(0, sigma), residual))

    return -result

end

function get_inits()
    inits = combine(groupby(data_s, :ID), :HDOM => minimum)
    inits = inits[!, 2]
    inits = [inits; 5; 0.5; 1]
end

const inits1 = get_inits();

@profview result1 = optimize(b -> loglik1(AGE, HDOM, b), inits1, NelderMead(), Optim.Options(iterations=1000000))

@time result1 = optimize(b -> loglik1(AGE, HDOM, b), inits1, NelderMead(), Optim.Options(iterations=1000000))

And then there’s going overboard:

using DifferentialEquations
using Plots
using Optim
using DataFrames
using Distributions
using BenchmarkTools
using LinearAlgebra
using StatsBase
using CSV

data_s = DataFrame(CSV.File("data_sim.csv"));

const nplot = length(unique(data_s.ID))
const AGE = data_s.AGE
const HDOM = data_s.HDOM
const UNIQUE_AGES = unique(AGE)
const TSPAN = (minimum(AGE), maximum(AGE))
const RESIDUAL = zeros(450)

function get_inits()
    inits = combine(groupby(data_s, :ID), :HDOM => minimum)
    inits = inits[!, 2]
    inits = [inits; 5; 0.5; 1]
end

function f(du, u, p, t)
    du .= p[1] .* (u) .* (1 / t .^ (1 + p[2]))
    nothing
end

const inits1 = get_inits();

prob = ODEProblem{true}(f, inits1[1:nplot], TSPAN, @view inits1[(nplot+1):(nplot+2)])
const integ = init(prob, Tsit5(), saveat=UNIQUE_AGES, save_start=true)

function DEsol_l(parms)
    p = @view(parms[(nplot+1):(nplot+2)])
    u0 = @view(parms[1:nplot])
    copyto!(integ.p, p)
    reinit!(integ, u0; t0=TSPAN[1], tf=TSPAN[2], erase_sol=false)
    integ.saveiter = 1
    copyto!(integ.sol.u[1], u0)
    solve!(integ)
    sol_v = vec(transpose(Array(integ.sol)))
    return sol_v
end

function loglik1(θ)
    time = AGE
    var = HDOM

    sigma = exp(θ[nplot+3])
    hpred = DEsol_l(θ)
    RESIDUAL .= hpred .- var
    RESIDUAL .= logpdf.(Normal(0, sigma), RESIDUAL)
    result = sum(RESIDUAL)

    return -result

end

@profview result1 = optimize(loglik1, inits1, NelderMead(), Optim.Options(iterations=1000000))

@time result1 = optimize(loglik1, inits1, NelderMead(), Optim.Options(iterations=1000000))

This got it under a second on my computer, so I’ll just stop there. There’s still more that can be done.
