Slow hierarchical ODE model with Turing

Hello,

I’m having a great deal of trouble with the current model I’m working on. I’ve asked about it on Slack but am still struggling, so apologies for the cross-post.

It’s a hierarchical model with 30 subjects, each of whom has between 3 and 5 time points, with 66 data points per time point. The model I have is given below:

function NetworkFKPP(u, p, t)
    κ, α = p  # diffusion and growth-rate parameters
    du = -κ * L * u .+ α .* u .* (1 .- u)  # L is the graph Laplacian (a global)
end

@model function hierarchical_FKPP_NCP(data, initial_conditions, time, scans, prob)

    n = size(data, 2)  # number of subjects

    σ ~ InverseGamma(2, 3)  # observation noise

    # population-level (hyper)priors
    κₘ ~ truncated(Normal(), 0, Inf)
    κₛ ~ truncated(Normal(), 0, Inf)

    αₘ ~ Normal()
    αₛ ~ truncated(Normal(), 0, Inf)

    # non-centred parameterisation: standardised subject-level offsets...
    κ ~ filldist(truncated(Normal(), 0, Inf), n)
    α ~ filldist(Normal(), n)

    # ...scaled and shifted into the subject-level parameters
    k = (κ .* κₛ) .+ κₘ
    a = (α .* αₛ) .+ αₘ

    @threads for i in 1:n
        prob_n = remake(prob, u0 = initial_conditions[:, i], p = [k[i], a[i]])
        predicted = solve(prob_n, Tsit5(), saveat = time[1:scans[i], i])
        cortical_preds = predicted[cortical_nodes, :]  # cortical_nodes is a global index vector
        Turing.@addlogprob! loglikelihood(MvNormal(vec(cortical_preds), σ), data[1:66*scans[i], i])
    end
end

model = hierarchical_FKPP_NCP(data, initial_conditions, time, pos_scans, problem)
posterior = sample(model, NUTS(0.65), 1_000)

data is a 330 × 30 array (66 cortical nodes × up to 5 scans, stacked per subject column); initial_conditions is an 83 × 78 array; time is a 5 × 30 array; scans is a length-30 vector; and problem is an ODEProblem set up with dummy values. The ODE is an FKPP (diffusion plus logistic growth) model on an undirected graph with 83 nodes, so L is an 83 × 83 Laplacian matrix.
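For concreteness, here are dummy stand-ins with the right shapes (the values are placeholders, not my real data):

# dummy inputs matching the shapes described above
data = rand(330, 30)                # 66 cortical nodes × 5 scans, stacked per subject column
initial_conditions = rand(83, 78)   # one 83-node initial state per subject
time = sort(rand(5, 30); dims = 1)  # up to 5 scan times per subject, increasing
pos_scans = rand(3:5, 30)           # number of scans actually observed per subject

For subject i, the first 66 * pos_scans[i] entries of data[:, i] line up with vec(cortical_preds), i.e. the 66 cortical nodes of scan 1, then scan 2, and so on.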

Trying to draw 1_000 samples, the ETA blows up to the order of days. I’ve made as many optimisations as I can using the docs and posts on Slack/Discourse/GitHub etc., but it’s still taking a very long time. I’ve also tried ReverseDiff, but it doesn’t seem to help.

I’m trying to reproduce a collaborator’s results: they ran a similar model with pymc3 in 1.5-2 hrs on their laptop. The main differences between my model and theirs are that they use explicit Euler for the ODE integration, they have 78 subjects, they use a centred parameterisation, and they put half-Cauchy priors on the noise and hierarchical s.d. parameters (sketched below). I can’t think of why any of these differences would have a significant impact on the Turing model, apart from more subjects maybe taking longer. I’ve tested Euler integration with dt=0.1 (as my collaborator had set it), but this causes instability in the model and it crashes, and using Euler with a smaller time step has more allocations than using Tsit5.
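For reference, my understanding of their centred parameterisation, translated into Turing (a sketch of the differing lines only; the exact half-Cauchy scales are an assumption, and κₘ/αₘ keep the same priors as in my model above):

# centred parameterisation: subject-level parameters drawn directly
# from the population distribution, no shifting/scaling of offsets
σ ~ truncated(Cauchy(0, 1), 0, Inf)   # half-Cauchy noise prior
κₛ ~ truncated(Cauchy(0, 1), 0, Inf)  # half-Cauchy hierarchical s.d.s
αₛ ~ truncated(Cauchy(0, 1), 0, Inf)

κ ~ filldist(truncated(Normal(κₘ, κₛ), 0, Inf), n)  # used directly as k
α ~ filldist(Normal(αₘ, αₛ), n)                     # used directly as a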

I’m running Julia 1.6 on a Linux machine with 32 threads. Env details:

      Status `~/Projects/TauPet/Project.toml`
  [6e4b80f9] BenchmarkTools v1.0.0
  [41bf760c] DiffEqSensitivity v6.48.0
  [0c46a032] DifferentialEquations v6.17.1
  [31c24e10] Distributions v0.25.2
  [093fc24a] LightGraphs v1.3.5
  [23992714] MAT v0.10.1
  [c7f686f2] MCMCChains v4.12.0
  [6fafb56a] Memoization v0.1.11
  [91a5bcdd] Plots v1.15.3
  [c3e4b0f8] Pluto v0.14.7
  [7f904dfe] PlutoUI v0.7.9
  [37e2e3b7] ReverseDiff v1.9.0
  [47aef6b3] SimpleWeightedGraphs v1.1.1
  [f3b207a7] StatsPlots v0.14.21
  [fce5fe82] Turing v0.16.0
  [e88e6eb3] Zygote v0.6.12

Any help would be greatly appreciated.

Thanks,
Pavan

The way you have implemented your differential equation is the slow way. You’ll want to use in-place functions. See https://tutorials.sciml.ai/html/introduction/03-optimizing_diffeq_code.html for some help getting started.


Thanks! I got a big speed-up in the ODE solve using an in-place function, and changing my L matrix from a sparse array to a dense array makes it faster still. I tried static vectors and broadcast fusion using @., but these benchmarked slower than the in-place function. So I ended up with just this:

function NetworkFKPP4!(du, u, p, t) 
    κ, α = p 
    du = -κ * L * u .+ α .* u .* (1 .- u)
end

I’ll see how this impacts inference speed now!

That’s not correct – you’ll notice your derivatives are zero like that. Writing du = ... only rebinds the local name du; it never writes into the buffer the solver passes in. It should be du .= -κ * L * u .+ α .* u .* (1 .- u), with .= for mutation.
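Going a step further, you can also avoid the temporary array from L * u by doing the matrix-vector product in place with mul! and fusing the rest (a sketch, assuming L is in scope as a dense matrix):

using LinearAlgebra: mul!

function NetworkFKPP!(du, u, p, t)
    κ, α = p
    mul!(du, L, u)                      # du = L * u, written in place
    @. du = -κ * du + α * u * (1 - u)   # fused broadcast, no temporaries
end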

Thank you for the correction, you’re absolutely right – my bad!