(Edit: here is copy-paste-able code for the problematic model: heartbeat_question_code.jl · GitHub)
I’m trying to learn Turing.jl, and after a few simple models, I wanted to fit a more complex model, that I could not fit in R’s brms.
Seismocardiography is the study of the heart via chest vibrations. It is usually done with high-frequency accelerometers attached to the chest.
I’m trying to look at this effect with the relatively low-frequency accelerometer in my phone (50 Hz).
I may write up a blog post about it, if I (we) get it working.
Data is recorded by simply placing a phone on my chest while lying on my back.
The idea is to fit a spline to data recorded over multiple heartbeats.
The problem is that I do not have information about when each heart beat starts.
Data looks like this:
using Turing
using DataFrames
using CSV
using Tidier
using Plots
using StatsPlots
# Location of the example accelerometer recording (CSV).
sample_data_path = "https://gist.githubusercontent.com/JohannesNE/df539760195619c642b0ea43b07c6ca3/raw/bb72096912b07b5221bd4a2202eac0ec7c9121a4/acc_apnea.csv"

# Download the file, parse it as CSV and materialise it as a DataFrame.
acc = download(sample_data_path) |> CSV.File |> DataFrame

# One panel per accelerometer axis, stacked vertically, plotted against time.
@df acc plot(
    :sec,
    [:x :y :z],
    title=["X" "Y" "Z"],
    layout=(3, 1),
    legend=false,
)
Most of the variation is in the Z axis (the axis normal to the phone screen, and hence to the chest), so we’ll just use that axis for the model.
To make sampling faster, we start with just 10 seconds
# Keep only the rows recorded before the 10 second mark.
acc10s = @filter acc sec < 10
# Response (z-axis acceleration) and the matching time stamps.
Y10s = @pull acc10s z
X10s = @pull acc10s sec
plot(X10s, Y10s)
Fitting a spline for just 1 beat
We can fit a spline to a single beat like this:
using BSplines
# Data for a single beat: rows recorded before the 1 second mark.
acc1s = @filter acc sec < 1
Y1s = @pull acc1s z
X1s = @pull acc1s sec

num_knots = 30
# Knots are spread uniformly over [0, 1] for now. They could probably be placed
# more densely around 0.1 sec into the beat, where most of the variation is.
knots_list = range(0, 1; length=num_knots)
# NOTE(review): the first argument is presumably the B-spline order (3 = quadratic);
# confirm against the BSplines.jl documentation.
basis = BSplineBasis(3, knots_list)
# Bayesian spline regression of `y` on `x` for a single beat.
#
# `x` holds the time stamps and `y` the observations; entries of `y` that are
# `missing` are sampled by the model instead of being conditioned on. The spline
# basis is read from the global `basis`.
#
# Returns a named tuple with the expected value `mu` at every `x` and `pred`
# (the `y` vector), intended to be extracted with `generated_quantities`.
@model function simple_spline_regression(x, y)
    # Prior on the observation noise *variance* (half-normal).
    σ² ~ truncated(Normal(0, 1); lower=0)
    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))
    # Set spline prior: one coefficient per basis function.
    w ~ MvNormal(zeros(length(basis)), 1)
    s = Spline(basis, w)
    # Calculate all the mu terms.
    mu = intercept .+ s.(x)
    # `Normal` is parameterised by the standard deviation, not the variance,
    # so convert once before the likelihood loop.
    σ = sqrt(σ²)
    # Likelihood. Kept as a loop so that `missing` entries of `y` are sampled.
    for i in eachindex(y)
        y[i] ~ Normal(mu[i], σ)
    end
    return (mu = mu, pred = y)
end
# Condition the single-beat model on the first second of data.
simple_spline_model = simple_spline_regression(X1s, Y1s)
# Draw 100 samples with the NUTS sampler using its default settings.
chains_simple_spline_model = sample(simple_spline_model, NUTS(), 100)
To visualize the model fit, we sample mu
from the posterior.
# Generate predictions of mu with a higher sample rate to get smooth curves.
X1s_pred = range(0, 1; length=200)
# Passing an all-`missing` response makes the model sample `y` instead of
# conditioning on it. The vector is filled with `missing` explicitly: an `undef`
# array of this type is not guaranteed to contain `missing`.
simple_spline_model_post = simple_spline_regression(
    X1s_pred, Vector{Union{Missing, Float64}}(missing, length(X1s_pred))
)
# Re-run the model's return statement for every posterior draw in the chain.
samples_simple_spline_model_post = generated_quantities(simple_spline_model_post, chains_simple_spline_model)
# `generated_quantities` returns an array of named tuples. Stack the `selector`
# entry of every tuple in the first column of `gq` as one row of the result.
function gq2matrix(gq, selector)
    first_column = gq[:, 1]
    rows = (draw[selector]' for draw in first_column)
    return reduce(vcat, rows)
end
# Matrix of `mu` values: one row per posterior draw, one column per prediction point.
samples_simple_spline_model_post_mu = gq2matrix(samples_simple_spline_model_post, :mu)
# Plot posterior draws as faint lines with their pointwise mean on top.
#
# Arguments
# - `pred_draws`: array with one draw per row and one `x` position per column.
# - `x`: x-axis values, one per column of `pred_draws`.
# - `n`: maximum number of individual draws to show (default 100).
#
# Starts a fresh plot with `plot()` and returns the current plot object, so
# further `plot!` calls can add layers such as the observed data.
function plot_post_pred!(pred_draws::Array, x, n = 100)
    plot()
    # Never index past the available draws: a chain shorter than `n` would
    # otherwise raise a BoundsError.
    n_shown = min(n, size(pred_draws, 1))
    for i in 1:n_shown
        plot!(x, pred_draws[i, :], legend=false, alpha=0.2)
    end
    # The mean uses every draw, not just the ones that are shown.
    epred_mean = mean.(eachcol(pred_draws))
    plot!(x, epred_mean, lw=3, color=:blue)
    return current()
end
# Plot the posterior draws of mu and their mean, then overlay the observed signal in black.
plot_post_pred!(samples_simple_spline_model_post_mu, X1s_pred)
plot!(X1s, Y1s, lw = 1, color = :black)
Now, the goal is to fit a spline like the one above, but to multiple beats.
To do this, we need to learn when each beat starts, except for the first, which we define as starting at sec=0.
It is probably reasonable to draw the intervals between beats from a log-normal distribution, e.g.:
# Candidate distribution for the interval between beats: log-normal with
# log-scale location log(1) = 0 (i.e. a median of 1) and log-scale sd 0.05.
plot(LogNormal(log(1), 0.05))
To get the time for each beat we could do cumsum
of draws from this distribution, but this makes the position of each beat dependent on all previous beats, which is probably not ideal. Is there a better way to sample these positions?
# Draw 10 intervals, accumulate them into beat times and mark each with a vertical line.
rand(LogNormal(log(1), 0.05), 10) |> cumsum |> vline
After we get the timing of each beat, we use the following function to calculate time since latest beat, which we can model a spline over.
"""
    get_time_since_event(time_vec, event_vec)

Return, for every entry of `time_vec`, the time elapsed since the latest event in
`event_vec` that lies strictly before it (a sample exactly at an event time is
still attributed to the previous event).

Both vectors are assumed to be sorted in ascending order: the event cursor only
ever moves forward. Samples earlier than the first event are measured relative to
the first event and therefore come out negative.

The result is a `Vector` whose element type is the promotion of both input
element types, so integer event times can be combined with floating-point sample
times, and automatic-differentiation number types in `event_vec` are preserved.

# Examples
```julia
julia> get_time_since_event([0.0, 0.5, 1.5, 3.5], [0.0, 1.0, 2.0, 3.0])
4-element Vector{Float64}:
 0.0
 0.5
 0.5
 0.5
```
"""
function get_time_since_event(time_vec, event_vec)
    T = promote_type(eltype(time_vec), eltype(event_vec))
    elapsed = zeros(T, length(time_vec))
    i_event = firstindex(event_vec)
    for (k, t) in zip(eachindex(elapsed), time_vec)
        # `while`, not `if`: more than one event may fall between two consecutive
        # samples, and the cursor must skip all of them.
        while i_event < lastindex(event_vec) && t > event_vec[i_event + 1]
            i_event += 1
        end
        elapsed[k] = t - event_vec[i_event]
    end
    return elapsed
end
We add a second spline over the entire 10 sec signal to catch any slower trend, though this may not be necessary for this, quite stationary, signal.
# Response (z-axis acceleration) and time stamps for the 10 second window.
Y10s = @pull acc10s z
X10s = @pull acc10s sec

# Spline setup
# Alternative not used here: quantile-based knots, e.g.
# quantile(X1s, range(0, 1; length=num_knots)).

# Per-beat spline: uniform knots over [0, 1] (time since the latest beat).
num_knots_beat = 30
knots_list_beat = range(0, 1; length=num_knots_beat)
basis_beat = BSplineBasis(3, knots_list_beat)

# Trend spline: uniform knots from 0 to the last time stamp of the recording.
num_knots_trend = 20
knots_list_trend = range(0, last(X10s); length=num_knots_trend)
basis_trend = BSplineBasis(3, knots_list_trend)
# Spline regression over several heartbeats with unknown beat onsets.
#
# `x` holds the time stamps and `y` the observations; entries of `y` that are
# `missing` are sampled by the model instead of being conditioned on. The expected
# signal is an intercept plus a per-beat spline (a function of the time since the
# latest beat, basis `basis_beat`) plus a slow trend spline over the recording
# time (basis `basis_trend`). Both bases are read from globals.
#
# Returns a named tuple with the expected value `mu` at every `x` and `pred`
# (the `y` vector), intended to be extracted with `generated_quantities`.
#
# The previous, unused `n_beats ~ Poisson(...)` variable has been removed: it was a
# discrete parameter, and the model is sampled with NUTS, which only supports
# continuous parameters.
@model function full_spline_regression(x, y)
    # Set prior for beat interval: median length and log-scale spread.
    mean_beat_interval ~ LogNormal(log(1), 0.4)
    rel_var_beat_interval ~ Exponential(0.1)
    # Prior on the observation noise *variance* (half-normal).
    σ² ~ truncated(Normal(0, 1); lower=0)
    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))
    # Set spline priors: one coefficient per basis function.
    w_beat ~ MvNormal(zeros(length(basis_beat)), 1)
    w_trend ~ MvNormal(zeros(length(basis_trend)), 1)
    s_beat = Spline(basis_beat, w_beat)
    s_trend = Spline(basis_trend, w_trend)
    # Beat vector.
    # The number of intervals is fixed at 12 for now. Ideally it would follow from
    # the length of the recording, but there may be a better way.
    beat_intervals ~ filldist(
        LogNormal(log(mean_beat_interval), rel_var_beat_interval), 12
    )
    # The first beat is defined to start at 0; later beats are cumulative intervals.
    beat_pos = [0; cumsum(beat_intervals)]
    time_since_beat = get_time_since_event(x, beat_pos)
    # Calculate all the mu terms.
    mu = intercept .+ s_beat.(time_since_beat) .+ s_trend.(x)
    # `Normal` is parameterised by the standard deviation, not the variance,
    # so convert once before the likelihood loop.
    σ = sqrt(σ²)
    # Likelihood. Kept as a loop so that `missing` entries of `y` are sampled.
    for i in eachindex(y)
        y[i] ~ Normal(mu[i], σ)
    end
    return (mu = mu, pred = y) # sampled with generated_quantities
end
# Condition the multi-beat model on the 10 second recording and draw 100 samples.
full_spline_model = full_spline_regression(X10s, Y10s)
chains_full_spline_model = sample(full_spline_model, NUTS(), 100)
# Passing an all-`missing` response makes the model sample `y` instead of
# conditioning on it. The vector is filled with `missing` explicitly: an `undef`
# array of this type is not guaranteed to contain `missing`.
full_spline_model_post = full_spline_regression(
    X10s, Vector{Union{Missing, Float64}}(missing, length(X10s))
)
samples_full_spline_model_post = generated_quantities(full_spline_model_post, chains_full_spline_model)
samples_full_spline_model_post_mu = gq2matrix(samples_full_spline_model_post, :mu)
# Plot the posterior draws of mu and their mean, then overlay the observed signal in black.
plot_post_pred!(samples_full_spline_model_post_mu, X10s)
plot!(X10s, Y10s, lw = 1, color = :black)
So this clearly does not work (It also did not work with 1000 samples). I hope someone finds this problem interesting enough to try fixing the model.
I believe my main problem is how to sample the position of the heart beats, but other improvements and suggestions are very welcome as well.