Turing.jl: unlikely posterior with a Bayesian differential equation?

Hey All,

I am trying to approximate the posterior distribution of a random variable inside a differential equation with a stochastic outcome. I am able to sample from the posterior distribution; however, I am unsure whether the produced posterior is actually correct.

To provide some context: I am working with a repeated time-to-event model. I have written the model as a differential equation in order to simulate the bleeding rate of hemophilia patients given a specific treatment:

function two_comp_hazard!(dA, A, p, t)
    Sₑ, tₑ, e, CL, V1, Q, V2, λ, γ, IC₅₀, η, I = p
    k₁₀ = CL / V1
    k₁₂ = Q / V1
    k₂₁ = Q / V2

    dA[1] = (I / V1) + A[2] * k₂₁ - A[1] * (k₁₀ + k₁₂) # drug concentration in plasma
    dA[2] = A[1] * k₁₂ - A[2] * k₂₁ # drug concentration in peripheral compartments
    dA[3] = 0. # the current bleeding hazard
    dA[4] = λ * exp(γ * ((t - tₑ) - 1)) * (1 - A[1] / (A[1] + IC₅₀)) * exp(η) # Cumulative hazard
    dA[5] = 0. # compartment representing the current survival probability
    A[3] = dA[4] # hazard
    A[5] = exp(-A[4]) # survival curve
end

Essentially, what happens in this model is that the bleeding hazard continuously increases the probability of an event (a bleed) occurring (i.e., the survival S decreases over time). At t = 0, I sample a value between 0 and 1 (Sₑ) which represents the survival value at which an event will happen. As soon as A[5] = S(t) = Sₑ, an event occurs and a callback is triggered. This event adds 1 to e (counting the number of events), resets the cumulative hazard (A[4]) to 0 and S to 1, and samples the next value of Sₑ at which an event will occur. It is possible for the model to encounter 0 events, which happens randomly when Sₑ is small and S(t) > Sₑ at the end of follow-up.
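(In equation form: the solver accumulates the cumulative hazard $\Lambda(t) = \int_{t_e}^{t} h(s)\,ds$ in A[4] and tracks the survival $S(t) = e^{-\Lambda(t)}$ in A[5]; a bleed fires the moment $S(t)$ falls to the sampled threshold $S_e$.)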

I want to infer the posterior of the random parameter η, which in dA[4] scales the hazard to match individually observed data. For example, if I know a patient had 14 bleeds while the baseline bleeding rate (λ) is 2.96, I would expect η to be relatively large. I’ve written the following Turing model in order to sample from the posterior:

@model function poisson(y)
    η ~ Normal(0., ω)
    p = [rand(dist), 0., 1., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η] # Note how num bleeds = 1 at t₀ here
    prob_i = remake(prob, p=vcat(p, 0.))
    solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
    y ~ Poisson(prob_i.p[3]) # Rejects samples that result in 0 bleeds.
end

Sampling from this model (using HMC(0.1, 5) and taking 5000 samples) with y = 14 + 1 (since the number of bleeds starts at 1 at t₀) gives me the following posterior:

I find it hard to believe that there is so much density at η < 1, especially with respect to the negative values. Maybe I am not taking enough samples?
Alternatively, maybe there is an issue with getting the posterior when the model itself is stochastic in nature?

Here’s an MVP:

import DifferentialEquations: DiscreteCallback, ContinuousCallback, CallbackSet, ODEProblem, remake, solve, Tsit5
using Distributions
using Turing

function _get_rate_over_t(Iᵢ::Matrix{Float64})::Matrix{Float64}
    times, _, rates, durations = eachcol(Iᵢ)
    timepoints = unique(vcat(times, times + durations))
    sort!(timepoints)
    k = length(timepoints)
    events = length(rates)
    event_matrix = zeros(k, events)
    for i in 1:events
        infusion_start = findfirst(t -> t == times[i], timepoints)
        infusion_end = findfirst(t -> t == times[i] + durations[i], timepoints)
        
        if infusion_start === nothing || infusion_end === nothing
            error("Cannot find infusion start or end in the dosing timepoints for event $i")
        end
        
        event_matrix[infusion_start, i] = rates[i]
        event_matrix[infusion_end, i] = -rates[i]
    end
    rate_over_t = cumsum(sum(event_matrix, dims=2), dims=1)
    # set all rates close to zero actually to 0. (results from inexact Float operations).
    rate_over_t[isapprox.(rate_over_t, 0; atol=1e-5)] .= 0. 
    return hcat(timepoints, rate_over_t)
end
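# For illustration (not in the original post): a single 1000 IU dose infused over
# one minute, using the row format [time (h), dose (IU), rate (IU/h), duration (h)]:
#   _get_rate_over_t([0.0 1000.0 60000.0 1/60])
# returns [0.0 60000.0; 0.0166667 0.0]: the summed infusion rate steps up to
# 60000 IU/h at t = 0 and back down to 0 one minute later.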

function get_weekly_callback(Iᵢ::Matrix{Float64}; S1::Float64=1.)
    times_rates = _get_rate_over_t(Iᵢ) .* [1. S1]
    times = times_rates[:, 1]
    rates = times_rates[:, 2]
    condition(u, t, p)::Bool = any(map(t_ -> isapprox(t_, t .% 168.), times))
    affect!(integrator) = integrator.p[end] = rates[findfirst(t -> isapprox(t, integrator.t .% 168.), times)] # Here we assume that only a single event happens at each t, which is reasonable.
    return DiscreteCallback(condition, affect!) # Setting save_positions breaks interpolation. We address this in the predict function when interpolating=true
end


function get_bleeding_callback()
    condition(u, t, integrator) = u[5] - integrator.p[1]
    function affect!(integrator) 
        # Set bleeding parameters
        integrator.p[3] += 1 # Add one event to e
        integrator.p[2] = integrator.t # Set tₑ to current time.
        integrator.p[1] = rand() # set new Sₑ

        # reset compartments
        integrator.u[4] = 0. # Cumulative hazard
        integrator.u[5] = 1. # Survival curve
    end

    return ContinuousCallback(condition, affect!; interp_points=5) # Sets save positions to true.
end

get_tstops(set::CallbackSet; weeks=52) = get_tstops(set.discrete_callbacks[1]; weeks=weeks)

function get_tstops(callback::DiscreteCallback; weeks=52)
    times = callback.condition.times
    if all(callback.condition.(0., [times; times .+ 168.], 0.))
        return vcat([times .+ 168. * week for week in 0:weeks - 1]...)
    else 
        return times
    end
end
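# Note: DiscreteCallback conditions are only evaluated at solver steps, so the
# infusion times must be passed via `tstops` to make the integrator land exactly
# on each infusion start/end. For a weekly-periodic schedule, get_tstops repeats
# the within-week times over all 52 weeks.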


get_callback(Iᵢ::Matrix{Float64}; S1::Float64=1.) = CallbackSet(get_weekly_callback(Iᵢ; S1=S1), get_bleeding_callback())

function two_comp_hazard!(dA, A, p, t)
    Sₑ, tₑ, e, CL, V1, Q, V2, λ, γ, IC₅₀, η, I = p
    k₁₀ = CL / V1
    k₁₂ = Q / V1
    k₂₁ = Q / V2

    dA[1] = (I / V1) + A[2] * k₂₁ - A[1] * (k₁₀ + k₁₂) # drug concentration in plasma
    dA[2] = A[1] * k₁₂ - A[2] * k₂₁ # drug concentration in peripheral compartments
    dA[3] = 0. # the current bleeding hazard
    dA[4] = λ * exp(γ * ((t - tₑ) - 1)) * (1 - A[1] / (A[1] + IC₅₀)) * exp(η) # Cumulative hazard
    dA[5] = 0. # compartment representing the current survival probability
    A[3] = dA[4] # hazard
    A[5] = exp(-A[4]) # survival curve
end

@model function poisson(y)
    η ~ Normal(0., ω)
    p = [rand(dist), 0., 1., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η] # Note how num bleeds = 1 at t₀ here
    prob_i = remake(prob, p=vcat(p, 0.))
    solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
    y ~ Poisson(prob_i.p[3]) # Rejects samples that result in 0 bleeds.
end

hours_in_year = 24 * 365.2422 # number of hours in a year.
# Distribution where we sample P(E = 1, S = s) from.
dist = Beta(2,2)
# Parameters from Abrantes 2019:
λ = 2.96 / hours_in_year
γ = -0.566 / hours_in_year
IC₅₀ = 10.2 # 10.2 is in IU/dL
η = 0.
# prediction from McEneny-King 2019 based on 21 yo, 80kg (63 FFM), 184cm individual:
p = [rand(dist), 0., 0., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η]
prob = ODEProblem(two_comp_hazard!, [0., 0., 0., 0., 1.], (-0.1, hours_in_year), vcat(p, 0.))

ω = sqrt(log((136/100)^2 + 1)) # sd of η corresponding to a 136% CV on the log-normal scale

# standard dosing treatment schedule: 3 × 1000 IU on Monday, Wednesday, and Friday at 12:00
dose = 1000
Iᵢ = [0 dose dose * 60 1/60; 48 dose dose * 60 1/60; 48*2 dose dose * 60 1/60] # rows: [time (h), dose (IU), rate (IU/h), duration (h)]
callback = get_callback(Iᵢ)

ABR = 14
m = poisson(ABR + 1) # Note we add one since we are starting at bleeds = 1 not bleeds = 0.
chain = sample(m, HMC(0.1, 5), 5000)
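For reference, one way to summarize and plot the resulting chain (assuming StatsPlots is available; this is a convenience, not part of the script above):

using StatsPlots # assumed extra dependency
describe(chain)  # posterior summary statistics and quantiles for η
plot(chain)      # trace and density plots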

Not really: there are results showing convergence of this kind of approach in the SDE case.

From the trace, it does look like the sampler is repeatedly venturing down there, so it’s not burn-in or anything. I wouldn’t expect this to change much if you take more samples, though you can see what happens.

When you plot values of the chain that are in that η < 1 region, how do the simulations look? Could it possibly be finding a way for other parameters to compensate, making the fit look plausible?
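A sketch of what that check could look like, reusing the objects from the MVP above (this is not code from the thread, and chain indexing may differ slightly across MCMCChains versions):

ηs = filter(<(1.0), vec(chain[:η]))
for ηᵢ in rand(ηs, 5)
    pᵢ = [rand(dist), 0., 1., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, ηᵢ] # bleed counter starts at 1
    prob_i = remake(prob, p=vcat(pᵢ, 0.))
    solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
    println("η = ", round(ηᵢ, digits=2), " → ", prob_i.p[3] - 1, " bleeds") # subtract the initial count
end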


With only one observation, I’d think the prior for η, which is practically a standard normal, has a sizable contribution to the posterior. Indeed, if you (artificially) provide more (identical) observations:

@model function poisson_multiple_solutions(y; ω, dist, p0, prob, callback)
    η ~ Normal(0., ω)
    for i in eachindex(y)
        p = [rand(dist); p0[2:end-1]; η; 0.0]
        p[3] = 1.0 # start the bleed counter at 1, as in the original model
        prob_i = remake(prob, p=p)
        solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
        y[i] ~ Poisson(prob_i.p[3])
    end
    return y
end

m = poisson_multiple_solutions(fill(ABR + 1, 1); ω, dist, p0=p, prob, callback)
m2 = poisson_multiple_solutions(fill(ABR + 1, 10); ω, dist, p0=p, prob, callback)
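# (the sampler is not specified in this reply; presumably analogous to the
# original post, e.g.:)
chain1 = sample(m, HMC(0.1, 5), 1000)
chain2 = sample(m2, HMC(0.1, 5), 1000)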

Taking 1000 samples from each results in:


Hey Chris,

Thanks for the reply.
I found that increasing the number of leapfrog steps (from 5 to 10) seems to (very) slightly ‘improve’ the posterior. Samples seem to be more concentrated around η = 1.8:

As for the bleed distributions when η < 1, here are some examples:

Based on these distributions I would say the likelihood of observing 14 bleeds is pretty low…

I was indeed expecting there to be uncertainty with only a single observation, but hoped the posterior would already be somewhat concentrated. I guess that because of the large variance of the prior, there is still going to be some density at lower η.

A problem in my use case is that we generally only have a single observation per patient, so I have to somehow convey the certainty in that observation. Repeating the same value multiple times is an option, but it of course raises the question of how many times it should be repeated ;). I guess I’ll experiment with that value until I find something that provides a good balance.

I would put more thought into the prior. Let’s assume the “correct” η is around 2; then the prior doesn’t put much weight on that region to begin with. Of course, the idea of a true value for the parameter is untenable, unless the data was sampled from the very model you are inferring.

You’ve chosen a very particular value for the variance of the prior. Is that founded on something? It should help inference to broaden the prior distribution, in order to put more weight on larger values of η.
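For instance (an illustrative tweak, not a value from any study), simply doubling the prior scale already puts appreciably more mass near η = 2:

η ~ Normal(0., 2ω) # prior sd ≈ 2.05 instead of ≈ 1.02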

Hey,

Yes, the prior and all other model parameters are based on a previous study (here). There is indeed not much prior mass on more extreme values of η. This is because, in general, the typical patient presents with a low bleeding frequency. A bleeding rate of 14 is quite extreme, but something we nonetheless encounter in a rare subset of patients.

Thinking about it, this is also the problem I am facing. The a priori expectation for a new patient is that they will bleed roughly 3 times a year (hence λ = 2.96). However, if I am treating a patient who I know bled 14 times last year, I should actually be more certain that this patient falls into the rarer subset.
It is very, very unlikely that a patient who has historically bled around 3 times a year would suddenly present with 14 bleeds.

In that case, it seems to me the prior is at odds with your own observations. If I understand correctly, the patients from the study did not show such large bleeding rates. Then again, a study is only a limited sample. Had they included more patients, chances are you’d see a much broader distribution for η.

Maybe patients cluster into low- and high-frequency bleeders and you could use a mixture model, i.e. different priors for each category?
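A minimal sketch of what such a prior could look like in Turing, with purely illustrative component locations and weights (reusing the globals from the MVP above):

@model function poisson_mixture(y)
    # hypothetical prior: most patients near η = 0, a rare high-frequency subgroup near η = 2
    η ~ MixtureModel([Normal(0., ω), Normal(2., ω)], [0.9, 0.1])
    p = [rand(dist), 0., 1., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η]
    prob_i = remake(prob, p=vcat(p, 0.))
    solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
    y ~ Poisson(prob_i.p[3])
end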


I think you are right in saying that the prior should likely be a mixture distribution rather than a single distribution. Since patients in most of these studies bleed fewer than 3 times, we probably get too much shrinkage from the prior. We’ll likely end up underestimating the number of bleeds for exactly the patient group we are most interested in.

Thanks!


For others revisiting this later, instead of passing multiple observations to the model, I used a power posterior to approximate the same behaviour.
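A sketch of one way to write that in Turing (the tempering power τ is a tuning choice, not a value from this thread): the Poisson log-likelihood is scaled by τ via @addlogprob!, which behaves like observing the same count τ times.

@model function poisson_power(y; τ=10.0)
    η ~ Normal(0., ω)
    p = [rand(dist), 0., 1., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η]
    prob_i = remake(prob, p=vcat(p, 0.))
    solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
    Turing.@addlogprob! τ * logpdf(Poisson(prob_i.p[3]), y)
end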

Currently working on creating new priors as a more reliable solution to the problem.