Help optimising Bayesian ODE model with Turing

Hi all

I’m working on my first Turing model and am also very new to Julia. I have finally gotten the model to the point of running at all – that is, it failed with an error telling me to increase maxiters after about 1.5 hours of sampling. I am aware of existing posts on this issue and that I can experiment with different solvers. My bigger concern at this point is that the model is much too slow for me to even try those options: after ~1.5 hours, the sampling process was only about 18% complete.
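
(For context, my understanding is that the solver choice and the iteration cap are just keyword arguments to solve – e.g. something roughly like the line below, where Rodas5 is picked purely as an illustration – but experimenting with these only becomes practical once a single run is reasonably fast.)

sol = solve(ode, Rodas5(), saveat=ts, maxiters=10^6) # illustrative only: the solver and maxiters value here are guesses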

So as a first step, I need to speed up my model substantially. This is in fact only a simplified version of the model I eventually want to build – ultimately I want to estimate multiple parameters, but here I am only trying to estimate a single one (a11):

using Parameters, LinearAlgebra, DifferentialEquations, Turing, MCMCChains, Distributions, Random, Plots, StatsPlots

"""Specify options"""

@with_kw mutable struct options
    TR::Float64
    TE::Float64
    duration::Float64
    dt::Float64
    layers::Int
    SNR::Float64
end


"""Define constants"""

@consts begin
    λ = 0.5
    ϑ₀ = 28.265*3
    r₀ = 110
    V₀ = 100*0.04
    
    κ = 0.65
    γ = 0.41
    τ = 0.98
    α = 0.32
    E₀ = 0.34
    ε = 0.4584
    τᵈ = 0.5
    pₕ = [κ γ τ α E₀ ε τᵈ]
end

"""Functions for the forward model"""

function f!(x, p, t)

    a11, a12, a21, a22, c11, c12, c21, c22, u, dt = p
    A = [a11 a21; a12 a22]
    C = [c11 c21; c12 c22]

    # Get input at time t
    idx = max(Int(ceil(t/dt)), 1)
    ut = u[idx,:]

    # Return dx/dt = A*x + C*u (used out-of-place by the ODEProblem below)
    A*x + C*ut
end

function h!(dz, z, p, t)
    
    H, HC, xn, dt = p

    idx = max(Int(ceil(t/dt)), 1)

    H = pₕ .* exp.(H)
    HC = [0 1; 0 0] .* exp.(HC)
    
    fv = (abs.(z[:,3])).^(1 ./ H[:,4])
    ff = ( 1 .- (1 .- H[:,5]).^(1 ./ z[:,2]) ) ./ H[:,5]
    
    # Compute state equations
    dz[:,1] = xn[idx] .- H[:,1].*z[:,1] - H[:,2].*(z[:,2] .- 1)
    dz[:,2] = z[:,1]
    dz[:,3] = (z[:,2] - fv + HC*z[:,5]) ./ H[:,3]
    dz[:,4] = (ff.*z[:,2] - fv.*z[:,4]./z[:,3] + HC*z[:,6]) ./ H[:,3]
    dz[:,5] = (-z[:,5] + z[:,3] .- 1)./(H[:,7])
    dz[:,6] = (-z[:,6] + z[:,4] .- 1)./(H[:,7])
end

function g(z, H, TE)
    
    E0 = E₀ .* exp.(H[:,5])
    εₕ = ε .* exp.(H[2,6])
    
    k₁ = 4.3.*ϑ₀.*E₀.*TE
    k₂ = εₕ.*r₀.*E₀.*TE
    k₃ = 1 - εₕ

    # Gather state columns 3 (v) and 4 (q) across all saved time points
    v = map(i -> z[i][:,3], 1:length(z))
    v = transpose(reduce(hcat, v))
    q = map(i -> z[i][:,4], 1:length(z))
    q = transpose(reduce(hcat, q))
    y = V₀ * (k₁'.*(1 .- q) + k₂'.*(1 .- q./v) + k₃'.*(1 .- v))
end

# Forward model
function forward_model(params, states, opt)
    a11, a12, a21, a22, c11, c12, c21, c22, u, H, HC = params
    x0, z0 = states
    
    ts = range(opt.dt, step=opt.dt, stop=opt.duration) # sample times
    tspan = (0.0, opt.duration)
    
    p = [a11, a12, a21, a22, c11, c12, c21, c22, u, opt.dt] # parameters
    ode = ODEProblem(f!, typeof.(p[1]).(x0), tspan, p)
    sol = solve(ode, AutoTsit5(Rosenbrock23()), saveat=ts, tstops=ts, dt=1e-2)
    if sol.retcode != :Success
        println(sol.retcode)
    end

    ph = [H, HC, sol.u, opt.dt] # parameters
    odeh = ODEProblem(h!, typeof.(p[1]).(z0), tspan, ph)
    solh = solve(odeh, AutoTsit5(Rosenbrock23()), saveat=ts, tstops=ts, dt = 1e-2)
    if solh.retcode != :Success
        println(solh.retcode)
    end

    y = g(solh.u, H, opt.TE)
end

"""Functions for model inversion"""

# Model inversion
Turing.setadbackend(:forwarddiff)

@model function invert_model(data, params, opt)
    
    # Unpack parameters
    a12, a21, a22, c11, c12, c21, c22, u, H, HC = params

    # Specify priors
    a11 ~ Normal(-0.5, 0.0625)
    
    # Set initial conditions
    x0 = zeros(opt.layers, 1) # initial condition
    z0 = zeros(opt.layers, 6) # initial condition
    z0[:,2:4] = exp.(z0[:,2:4]) # exponentiate haemodynamic state variables
    states = [x0, z0]

    P = [a11, a12, a21, a22, c11, c12, c21, c22, u, H, HC]
    predicted = forward_model(P, states, opt)  

    σ ~ LogNormal(6, 1/128) # Check what this should actually be

    for i = 1:size(predicted, 1)
        data[i,:] ~ MvNormal(predicted[i,:], σ)
    end
end


## 1. Simulate data
# Initialise options:
opt = options(TR = 2, TE = 0.03, duration = 600.0, dt = 0.05, layers = 2, SNR = 4.0) # keywords optional
# Generate a single impulse as input:
u = zeros(Int(opt.duration / opt.dt), 2)
u[6000:6004] .= 1
# Set parameters:
H = zeros(2,7)
HC = log(λ) * [0 1; 0 0]
parameters = [-1.1, 0, -1.1, 0, 0.3, 0, 0.3, 0, u, H, HC]
# Initial states:
x0 = zeros(opt.layers, 1)
z0 = zeros(opt.layers, 6)
z0[:,2:4] = exp.(z0[:,2:4])
states = [x0, z0]
# Run model:
Y = forward_model(parameters, states, opt) 
Yn = Y + 0.01*randn(size(Y))
plot(Yn)

## 2. Estimate parameters
# Set parameters:
p_set = [0, -1.1, 0, 0.3, 0, 0.3, 0, u, H, HC] # only estimate a11 for now
model = invert_model(Yn, p_set, opt)
chain = mapreduce(c -> sample(model, NUTS(.65), 1000), chainscat, 1:3)
# Plot
plot(chain)

I would appreciate any help and suggestions on how to adjust my model so it runs faster.

Why did it hit maxiters? Did you plot the solution at the parameters where it did so? Did you check whether it was diverging at those parameters? Not the parameters you started at, but the ones the MCMC walked to. Do @show prob.p before the solve to see what that would be.
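
For example, inside your forward_model, right before each call to solve (and likewise before the solve of odeh), something along these lines:

ode = ODEProblem(f!, typeof.(p[1]).(x0), tspan, p)
@show ode.p # prints the parameter values the sampler is currently trying
sol = solve(ode, AutoTsit5(Rosenbrock23()), saveat=ts, tstops=ts, dt=1e-2)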

I apologise in advance for the very basic question… how can I plot the parameters? I see how to do this with a model that completes successfully, but not with one that fails. I am printing the estimated parameter using @show prob.p, but I’m not getting very far with this approach: if I don’t happen to be looking at the REPL at the moment the error occurs (which has been the case on every attempt so far), I never see the parameter value at which it fails, because the stack trace is so long that the output scrolls out of view.
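
(The only workaround I can think of is to write the value somewhere persistent instead of relying on the REPL scrollback – e.g. something like the snippet below, placed just before the solve – but I don’t know whether that is a sensible approach.)

open("a11_log.txt", "a") do io # hypothetical log file, appended to on every model evaluation
    println(io, p[1])          # p[1] is a11 in my forward_model
end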

I’ll gladly take any additional hints on debugging in Julia in general – this is something I’m finding particularly difficult to adjust to (I come from a Matlab background and am used to setting breakpoints and stepping through functions).
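
(From what I can tell, Debugger.jl offers Matlab-style breakpoints and stepping, along the lines of the call below, though I’m not sure how practical it is for a model that repeatedly calls an ODE solver.)

using Debugger
@enter forward_model(parameters, states, opt) # step into the forward model interactively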

I’m just going to make a video on how to debug all of this. But in general, if you just print out what is going to be solved, you can always recreate the error without the loop. This will simplify things by 100000x, so please do this while debugging.
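
That is, once the offending parameter values have been printed, paste them back in and run the solve directly, outside Turing – roughly like this (the a11 value below is just a placeholder, not an actual failing value):

ts = range(opt.dt, step=opt.dt, stop=opt.duration)
p_bad = [-0.9, 0.0, -1.1, 0.0, 0.3, 0.0, 0.3, 0.0, u, opt.dt] # substitute the printed a11 (and other values) here
ode = ODEProblem(f!, zeros(opt.layers, 1), (0.0, opt.duration), p_bad)
sol = solve(ode, AutoTsit5(Rosenbrock23()), saveat=ts, tstops=ts, dt=1e-2)
plot(sol) # check whether the solution is blowing up at these parameters
# the same can be done for the second (haemodynamic) solve if that is the one hitting maxiters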