Hey All,
I am trying to approximate the posterior distribution of a random variable inside a differential equation with a stochastic outcome. I am able to sample from the posterior, but I am unsure whether the resulting posterior is actually correct.
To provide some context: I am working with a repeated time-to-event model. I have written the model as a differential equation in order to simulate the bleeding rate of hemophilia patients when given a specific treatment (see two_comp_hazard! in the code below).
Essentially, the bleeding hazard continuously increases the probability of an event (a bleed) occurring, i.e. the survival S decreases over time. At t = 0, I sample a value between 0 and 1 (Sₑ), which represents the survival value at which an event will happen. As soon as A[5] = S(t) = Sₑ, an event occurs and a callback is triggered. This callback adds 1 to e (the number of events), resets the cumulative hazard (A[4]) to 0 and S to 1, and samples the next value of Sₑ at which an event will occur. It is possible for the model to encounter 0 events, which happens randomly when Sₑ is small and S(t) > Sₑ at the end of follow-up.
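To make the event mechanism concrete, here is a minimal sketch of the same threshold rule for a constant hazard, where the crossing time can be solved in closed form instead of by a callback. The helper count_events is hypothetical and not part of my model; it assumes Uniform(0, 1) thresholds, no drug effect, no time-dependence, and rounds the year to 8766 hours.

```julia
using Random

# Hypothetical helper (not part of the model below): count events over
# `t_end` hours for a constant hazard `h`, using the same threshold rule.
function count_events(h::Float64, t_end::Float64; rng=Random.default_rng())
    n = 0
    t = 0.0
    while true
        Sₑ = rand(rng)          # survival value at which the next event fires
        Δt = -log(Sₑ) / h       # solve exp(-h * Δt) = Sₑ for Δt
        t += Δt
        t > t_end && return n   # S(t) never reached Sₑ before follow-up ended
        n += 1                  # event: cumulative hazard and S start over
    end
end

rng = MersenneTwister(1)
# Assumed baseline: 2.96 events per year, with a year rounded to 8766 hours.
counts = [count_events(2.96 / 8766, 8766.0; rng=rng) for _ in 1:10_000]
mean_count = sum(counts) / length(counts)
```

Under these assumptions the simulated counts should average close to the baseline rate of 2.96 per year, which is the behaviour I expect from the full model when η = 0 and no drug is given.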
I want to get the posterior of the random parameter η, which in dA[4] scales the hazard to an individual's observed data. For example, if I know a patient had 14 bleeds and the baseline bleeding rate (λ) is 2.96 per year, I would expect η to be clearly positive. I’ve written the following Turing model in order to sample from the posterior (the poisson model in the code below):
Sampling from this model (using HMC(0.1, 5), taking 5000 samples) with y = 14 + 1 (since the number of bleeds starts at 1 at t₀) gives me the following posterior:
I find it hard to believe that there is so much density at η < 1, especially at negative values. Maybe I am not taking enough samples?
Alternatively, maybe there is an issue with getting the posterior when the model itself is stochastic in nature?
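For reference, this is the rough reasoning behind my expectation for η. It is only a back-of-the-envelope sketch under simplifying assumptions that are not part of the model: it ignores the drug effect and the time-dependent term in dA[4], so that the expected number of bleeds per year reduces to λ·exp(η).

```latex
\mathbb{E}[\text{bleeds per year}] \approx \lambda \, e^{\eta}
\quad\Longrightarrow\quad
\eta \approx \log\frac{14}{2.96} \approx 1.55
```

Since the drug effect can only lower the hazard, I would expect the value needed to explain 14 bleeds to be, if anything, higher than this.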
Here’s an MWE:
import DifferentialEquations: DiscreteCallback, ContinuousCallback, CallbackSet, ODEProblem, remake, solve, Tsit5
using Distributions
using Turing
function _get_rate_over_t(Iᵢ::Matrix{Float64})::Matrix{Float64}
    times, _, rates, durations = eachcol(Iᵢ)
    # Sort the timepoints so that the cumulative sum below runs in time order.
    timepoints = sort(unique(vcat(times, times + durations)))
    k = length(timepoints)
    events = length(rates)
    event_matrix = zeros(k, events)
    for i in 1:events
        infusion_start = findfirst(t -> t == times[i], timepoints)
        infusion_end = findfirst(t -> t == times[i] + durations[i], timepoints)
        if infusion_start === nothing || infusion_end === nothing
            error("Cannot find infusion start or end in the dosing timepoints for event $i")
        end
        event_matrix[infusion_start, i] = rates[i]
        event_matrix[infusion_end, i] = -rates[i]
    end
    rate_over_t = cumsum(sum(event_matrix, dims=2), dims=1)
    # Set all rates close to zero to exactly 0 (they result from inexact Float operations).
    rate_over_t[isapprox.(rate_over_t, 0; atol=1e-5)] .= 0.
    return hcat(timepoints, rate_over_t)
end

function get_weekly_callback(Iᵢ::Matrix{Float64}; S1::Float64=1.)
    times_rates = _get_rate_over_t(Iᵢ) .* [1. S1]
    times = times_rates[:, 1]
    rates = times_rates[:, 2]
    condition(u, t, p)::Bool = any(map(t_ -> isapprox(t_, t .% 168.), times))
    # Here we assume that only a single event happens at each t, which is reasonable.
    function affect!(integrator)
        integrator.p[end] = rates[findfirst(t -> isapprox(t, integrator.t .% 168.), times)]
    end
    # Setting save_positions breaks interpolation. We address this in the predict function when interpolating=true.
    return DiscreteCallback(condition, affect!)
end

function get_bleeding_callback()
    condition(u, t, integrator) = u[5] - integrator.p[1]
    function affect!(integrator)
        # Set bleeding parameters
        integrator.p[3] += 1 # Add one event to e
        integrator.p[2] = integrator.t # Set tₑ to current time.
        integrator.p[1] = rand() # Set new Sₑ
        # Reset compartments
        integrator.u[4] = 0. # Cumulative hazard
        integrator.u[5] = 1. # Survival curve
    end
    return ContinuousCallback(condition, affect!; interp_points=5) # Sets save positions to true.
end

get_tstops(set::CallbackSet; weeks=52) = get_tstops(set.discrete_callbacks[1]; weeks=weeks)

function get_tstops(callback::DiscreteCallback; weeks=52)
    times = callback.condition.times
    if all(callback.condition.(0., [times; times .+ 168.], 0.))
        return vcat([times .+ 168. * week for week in 0:weeks - 1]...)
    end
    return times
end

get_callback(Iᵢ::Matrix{Float64}; S1::Float64=1.) = CallbackSet(get_weekly_callback(Iᵢ; S1=S1), get_bleeding_callback())

function two_comp_hazard!(dA, A, p, t)
    Sₑ, tₑ, e, CL, V1, Q, V2, λ, γ, IC₅₀, η, I = p
    k₁₀ = CL / V1
    k₁₂ = Q / V1
    k₂₁ = Q / V2
    dA[1] = (I / V1) + A[2] * k₂₁ - A[1] * (k₁₀ + k₁₂) # drug concentration in plasma
    dA[2] = A[1] * k₁₂ - A[2] * k₂₁ # drug concentration in peripheral compartments
    dA[3] = 0. # the current bleeding hazard
    dA[4] = λ * exp(γ * ((t - tₑ) - 1)) * (1 - A[1] / (A[1] + IC₅₀)) * exp(η) # cumulative hazard
    dA[5] = 0. # compartment representing the current survival probability
    A[3] = dA[4] # hazard
    A[5] = exp(-A[4]) # survival curve
end

@model function poisson(y)
    η ~ Normal(0., ω)
    p = [rand(dist), 0., 1., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η] # Note how num bleeds = 1 at t₀ here
    prob_i = remake(prob, p=vcat(p, 0.))
    solve(prob_i, Tsit5(), tstops=get_tstops(callback), callback=callback)
    y ~ Poisson(prob_i.p[3]) # Rejects samples that result in 0 bleeds.
end

hours_in_year = 24 * 365.2422 # number of hours in a year.
# Distribution where we sample P(E = 1, S = s) from.
dist = Beta(2,2)
# Parameters from Abrantes 2019:
λ = 2.96 / hours_in_year
γ = -0.566 / hours_in_year
IC₅₀ = 10.2 # 10.2 is in IU/dL
η = 0.
# prediction from McEneny-King 2019 based on 21 yo, 80kg (63 FFM), 184cm individual:
p = [rand(dist), 0., 0., 2.73, 35.8, 1.42, 6.0, λ, γ, IC₅₀, η]
prob = ODEProblem(two_comp_hazard!, [0., 0., 0., 0., 1.], (-0.1, hours_in_year), vcat(p, 0.))
ω = sqrt(log((136/100)^2 + 1))
# standard dosing treatment schedule: 3 x 1000 IU on monday, wednesday, and friday 12:00
dose = 1000
Iᵢ = [0 dose dose * 60 1/60; 48 dose dose * 60 1/60; 48*2 dose dose * 60 1/60]
callback = get_callback(Iᵢ)
ABR = 14
m = poisson(ABR + 1) # Note we add one since we are starting at bleeds = 1 not bleeds = 0.
chain = sample(m, HMC(0.1, 5), 5000)