Modelling events with random intervals between them (heart beats) in a time series with Turing.jl

(Edit: here is copy-paste-able code for the problematic model: heartbeat_question_code.jl · GitHub)

I’m trying to learn Turing.jl, and after a few simple models, I wanted to fit a more complex model that I could not fit in R’s brms.

Seismocardiography is the study of the heart via chest vibrations. It is usually done with high-frequency accelerometers attached to the chest.

I’m trying to look at this effect with the relatively low-frequency accelerometer in my phone (50 Hz).

I may write up a blog post about it if I (we) get it working.

Data is recorded by simply placing a phone on my chest while lying on my back.

The idea is to fit a spline to data recorded over multiple heartbeats.

The problem is that I do not have information about when each heart beat starts.

Data looks like this:

using Turing
using DataFrames
using CSV
using Tidier
using Plots
using StatsPlots

sample_data_path = "https://gist.githubusercontent.com/JohannesNE/df539760195619c642b0ea43b07c6ca3/raw/bb72096912b07b5221bd4a2202eac0ec7c9121a4/acc_apnea.csv"

acc = DataFrame(CSV.File(download(sample_data_path)))

@df acc plot(:sec, [:x :y :z],
    title = ["X" "Y" "Z"],
    layout=(3, 1),
    legend = false)

[Figure 1: X, Y and Z acceleration vs. time, one panel per axis]

Most of the variation is in the Z axis (the axis normal to the phone screen, and hence to the chest), so we’ll just use that axis for the model.

To make sampling faster, we start with just the first 10 seconds:

acc10s = @filter(acc, sec < 10)

Y10s = @pull(acc10s, z)
X10s = @pull(acc10s, sec)

plot(X10s, Y10s)

[Figure 2: Z acceleration over the first 10 seconds]

Fitting a spline for just 1 beat

We can fit a spline to a single beat like this:

using BSplines

acc1s = @filter(acc, sec < 1)

Y1s = @pull(acc1s, z)
X1s = @pull(acc1s, sec)

num_knots = 30  
# We can probably place these knots more densely around 0.1 sec into the beat,
# where most of the variation is. For now we do it uniformly.
knots_list = range(0, 1; length=num_knots) 
basis = BSplineBasis(3, knots_list)

@model function simple_spline_regression(x, y) 	
	# Set variance prior.
    σ² ~ truncated(Normal(0, 1); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set spline prior
    w ~ MvNormal(zeros(length(basis)), 1)

    s = Spline(basis, w)

    # Calculate all the mu terms.
    mu = intercept .+ s.(x)

    # Likelihood
    for i in eachindex(y)
        y[i] ~ Normal(mu[i], σ²)
    end

    return (mu = mu, pred = y)
end

simple_spline_model = simple_spline_regression(X1s, Y1s)

chains_simple_spline_model = sample(simple_spline_model, NUTS(), 100)
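
The comment in the code above suggests placing knots more densely around 0.1 s into the beat. A minimal sketch of one way to build such a knot vector, assuming a coarse uniform grid outside a window and a finer grid inside it; the window (0.05 to 0.2 s) and the grid sizes are arbitrary, hypothetical choices:

```julia
# Hypothetical non-uniform knot layout for one beat (times in seconds).
window = (0.05, 0.2)                   # region assumed to hold most of the variation
coarse = range(0, 1; length = 15)      # uniform knots over the whole beat
fine = range(window...; length = 12)   # denser knots inside the window

# Keep coarse knots only outside the window so no two knots end up nearly on top
# of each other, then merge and sort (BSplineBasis expects increasing breakpoints).
outside = filter(k -> k < window[1] || k > window[2], collect(coarse))
knots_nonuniform = sort([outside; collect(fine)])

println(length(knots_nonuniform), " knots, ",
        count(k -> window[1] <= k <= window[2], knots_nonuniform), " inside the window")
```

In the single-beat model, knots_nonuniform would take the place of knots_list in BSplineBasis(3, knots_list).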

To visualize the model fit, we sample mu from the posterior.

# Generate predictions of mu with a higher sample rate to get smooth curves.
X1s_pred = range(0, 1, 200)

simple_spline_model_post = simple_spline_regression(X1s_pred, Vector{Union{Missing, Float64}}(undef, length(X1s_pred)))

samples_simple_spline_model_post = generated_quantities(simple_spline_model_post, chains_simple_spline_model) 

# generated_quantities returns an array of tuples. This can be unpacked to a matrix
function gq2matrix(gq, selector)
    reduce(vcat, getindex(samp, selector)' for samp in gq[:,1])
end

samples_simple_spline_model_post_mu = gq2matrix(samples_simple_spline_model_post, :mu)

function plot_post_pred!(pred_draws::Array, x, n = 100)
    plot()
    for i in 1:n
        plot!(x, pred_draws[i,:], legend=false, alpha = 0.2)
    end

    epred_mean = mean.(eachcol(pred_draws))
    plot!(x, epred_mean, lw = 3, color = :blue)


    current()
end

plot_post_pred!(samples_simple_spline_model_post_mu, X1s_pred)
plot!(X1s, Y1s, lw = 1, color = :black)

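[Figure 3: posterior draws of mu for one beat, with the posterior mean in blue and the data in black]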

Now, the goal is to fit a spline like the one above, but to multiple beats.
To do this, we need to learn when each beat starts, except for the first, which we define as starting at sec = 0.

It is probably reasonable to draw the intervals between beats from a log-normal distribution, e.g.:

plot(LogNormal(log(1), 0.05))
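
To make the parameters of this distribution concrete: LogNormal(log(1), 0.05) has a median interval of exp(log(1)) = 1 s and, for such a small log-scale sd, a spread of roughly 5% around it. A small sketch with Distributions.jl that prints these summaries and the implied heart-rate range:

```julia
using Distributions

d = LogNormal(log(1), 0.05)   # interval between beats, in seconds

qs = [quantile(d, p) for p in (0.025, 0.975)]   # central 95% range of intervals
println("median interval: ", median(d), " s")
println("sd of interval:  ", round(std(d); digits = 4), " s")
println("95% of intervals between ", round.(qs; digits = 3), " s")
# A longer interval means a lower heart rate, hence the reverse.
println("heart rate between ", round.(60 ./ reverse(qs); digits = 1), " bpm")
```

So this prior lets consecutive intervals vary by only a few percent: the 95% range corresponds to roughly 54 to 66 bpm around a median of 60 bpm.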

[Figure 4: density of LogNormal(log(1), 0.05)]

To get the time of each beat, we could take the cumsum of draws from this distribution, but this makes the position of each beat depend on all previous beats, which is probably not ideal. Is there a better way to sample these positions?

vline(cumsum(rand(LogNormal(log(1), 0.05), 10)))
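
One alternative to the cumsum, offered as a hypothetical sketch rather than a tested model component: place beat k at k times a common period plus its own independent jitter, so each beat depends only on the period and its own jitter instead of on all earlier intervals. The simulation below compares how much the position of the last (11th) beat varies under the two schemes; the period, relative sd and number of simulations are assumed values:

```julia
using Random, Statistics

rng = MersenneTwister(1)
n, period, rel_sd, n_sims = 10, 1.0, 0.05, 10_000

# (a) As in the post: positions = [0; cumsum(intervals)], intervals ~ LogNormal(log(period), rel_sd)
cumsum_positions(rng) = [0.0; cumsum(period .* exp.(rel_sd .* randn(rng, n)))]

# (b) Jittered grid (hypothetical): beat k at k * period + jitter_k,
#     jitter_k ~ Normal(0, rel_sd * period); the first beat stays fixed at 0
grid_positions(rng) = [0.0; (1:n) .* period .+ rel_sd * period .* randn(rng, n)]

last_a = [cumsum_positions(rng)[end] for _ in 1:n_sims]
last_b = [grid_positions(rng)[end] for _ in 1:n_sims]

# Under (a) the spread grows roughly like sqrt(n); under (b) it stays near rel_sd * period
println("sd of last beat position, cumsum:        ", round(std(last_a); digits = 3))
println("sd of last beat position, jittered grid: ", round(std(last_b); digits = 3))
```

The trade-off is that the jittered grid assumes one fixed period for the whole window, so it cannot follow a slow drift in heart rate, and a large jitter could reorder beats, which get_time_since_event assumes does not happen. In a Turing model, the jitter could be a filldist of Normals added to (1:n) .* mean_beat_interval.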

[Figure 5: vertical lines at 10 simulated beat times]

After we get the timing of each beat, we use the following function to calculate the time since the latest beat, over which we can fit a spline.

function get_time_since_event(time_vec, event_vec) 
    ann_index = zeros(eltype(event_vec), length(time_vec))
    i_ann::Int64 = 1

    for i in eachindex(time_vec)
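        # `while` rather than `if`, so the index also catches up when several events fall between two samples
        while i_ann < length(event_vec) && time_vec[i] > event_vec[i_ann + 1]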
            i_ann += 1
        end

        ann_index[i] = time_vec[i] - event_vec[i_ann]
    end

    ann_index
end
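
A possible vectorized alternative, sketched as a hypothetical helper (time_since_last_event is not from the thread): Julia’s searchsortedlast finds, for each sample, the last beat at or before it with a binary search. Like the loop version, it assumes the beat positions are sorted; unlike it, ties (a sample exactly at a beat) go to the new beat, whereas the strict > in the loop keeps the previous one.

```julia
# Hypothetical alternative to get_time_since_event: one binary search per sample.
# Assumes event_vec is sorted in increasing order.
function time_since_last_event(time_vec, event_vec)
    map(time_vec) do t
        # index of the last event with event_vec[j] <= t; 0 if t precedes all events
        j = searchsortedlast(event_vec, t)
        # before the first event, measure from the first event (as the loop does)
        t - event_vec[max(j, 1)]
    end
end

times = collect(0.0:0.25:2.0)
beats = [0.0, 0.8, 1.6]
println(time_since_last_event(times, beats))
```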

We add a second spline over the entire 10 sec signal to catch any slower trend, though this may not be necessary for this, quite stationary, signal.
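
Written out, the full model below amounts to the following, as read from the code (α is intercept, m is mean_beat_interval, r is rel_var_beat_interval, b_k are the beat positions and t_i the sample times; ties t_i = b_k are ignored):

$$
\mu_i = \alpha + s_{\text{beat}}\big(t_i - b_{k(i)}\big) + s_{\text{trend}}(t_i),
\qquad k(i) = \max\{k : b_k \le t_i\}
$$

$$
b_0 = 0, \qquad b_k = b_{k-1} + \Delta_k, \qquad \Delta_k \sim \operatorname{LogNormal}(\log m,\ r), \qquad k = 1, \dots, 12
$$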

Y10s = @pull(acc10s, z)
X10s = @pull(acc10s, sec)

# Spline setup
num_knots_beat = 30  
num_knots_trend = 20  
# knots_list = quantile(X1s, range(0, 1; length=num_knots))
knots_list_beat = range(0, 1; length=num_knots_beat)
knots_list_trend = range(0, last(X10s); length=num_knots_trend)
basis_beat = BSplineBasis(3, knots_list_beat)
basis_trend = BSplineBasis(3, knots_list_trend)

@model function full_spline_regression(x, y) 
    obs_len_s = last(x) - first(x)   	
    # Set prior for beat interval
    mean_beat_interval ~ LogNormal(log(1), 0.4)
    rel_var_beat_interval ~ Exponential(0.1)
    n_beats ~ Poisson((1/mean_beat_interval) * obs_len_s) # Not currently used

	# Set variance prior.
    σ² ~ truncated(Normal(0, 1); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set spline prior
    w_beat ~ MvNormal(zeros(length(basis_beat)), 1)
    w_trend ~ MvNormal(zeros(length(basis_trend)), 1)

    s_beat = Spline(basis_beat, w_beat)
    s_trend = Spline(basis_trend, w_trend)

    # Beat vector
    # Instead of drawing 12 intervals, this should be n_beats (possibly drawn from a Poisson), but there may be a better way
    beat_intervals ~ filldist(LogNormal(log(mean_beat_interval),
                                        rel_var_beat_interval), 12) 
    beat_pos = [0; cumsum(beat_intervals)]
    time_since_beat = get_time_since_event(x, beat_pos)

    # Calculate all the mu terms.
    mu = intercept .+ s_beat.(time_since_beat) .+ s_trend.(x)
    
    for i in eachindex(y)
        # Likelihood
        y[i] ~ Normal(mu[i], σ²)
    end

    return (mu = mu, pred = y) # sampled with generated_quantities
end

full_spline_model = full_spline_regression(X10s, Y10s)

chains_full_spline_model = sample(full_spline_model, NUTS(), 100)

full_spline_model_post = full_spline_regression(X10s, Vector{Union{Missing, Float64}}(undef, length(X10s)))

samples_full_spline_model_post = generated_quantities(full_spline_model_post, chains_full_spline_model) 

samples_full_spline_model_post_mu = gq2matrix(samples_full_spline_model_post, :mu)

plot_post_pred!(samples_full_spline_model_post_mu, X10s)
plot!(X10s, Y10s, lw = 1, color = :black)

[Figure 6: posterior draws of mu over the 10 s signal (NUTS), with the data in black]

So this clearly does not work (it also did not work with 1000 samples). I hope someone finds this problem interesting enough to try fixing the model.

I believe my main problem is how to sample the positions of the heart beats, but other improvements and suggestions are very welcome as well.


DISCLAIMER: I don’t know Turing at all, so take this with a grain of salt.

I think get_time_since_event might be the issue, because the function giving event_vec[i_ann] is piecewise constant: in particular, its gradient is uninformative. So if the Hamiltonian Monte Carlo procedure starts by sampling a beat sequence that is slightly misaligned (like having two measured beats within the same sampled interval), I don’t see how it can recover from that.
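
A toy illustration of this point, using a hypothetical helper time_since (not from the thread): move the second beat p across a sample at t = 1.0 s and look at the time since the most recent beat at that sample. While p < t, it equals t - p (slope -1 in p); once p passes t, it jumps to t - 0 = 1.0 and no longer depends on p at all (slope 0). A gradient only sees these local slopes, never the jump.

```julia
# Hypothetical helper: time since the most recent beat at time t
# (beats sorted, first beat at 0).
time_since(t, beats) = t - beats[max(searchsortedlast(beats, t), 1)]

# Move the second beat across the sample at t = 1.0 s.
for p in (0.97, 0.99, 1.01, 1.03)
    println("second beat at ", p, " s -> time since beat at t = 1.0 s: ",
            round(time_since(1.0, [0.0, p]); digits = 3))
end
```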

Not sure how to fix that, but here’s a naive idea that I had. Could you define a latent “phase variable” at each time of your signal, basically representing “how long since the beginning of the last beat”? The marginal distribution of each of these could be taken as uniform over the maximum beat duration, or something similar. This decouples the beats from each other, and makes the whole procedure more nicely differentiable.

If what you want is the start of each beat, seems to me that you could make a model for a single beat (say the spline in your third image), then use that to run a matched filter against the “Z” data from your first plot. Every time your matched filter gives a peak, that’s the peak z for a beat, and you can back the start of each beat out from that.
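
A minimal sketch of that matched-filter idea in plain Julia, assuming a template that starts at the beat onset (for instance the posterior mean of the single-beat spline sampled at 50 Hz). matched_filter and pick_peaks are hypothetical helpers: the first slides the mean-centred template along the signal and takes the dot product at every offset, the second keeps local maxima above a threshold that are at least min_dist samples apart. Because the score is indexed by where the template starts, each peak index is directly a beat-start sample. The usage at the end checks this on synthetic data with three pulses at known positions:

```julia
# Sliding dot product of the mean-centred template with the signal ("valid" mode:
# score[i] compares the template with signal[i:i+m-1]).
function matched_filter(signal::AbstractVector, template::AbstractVector)
    tpl = template .- sum(template) / length(template)
    m = length(tpl)
    return [sum(tpl .* view(signal, i:i+m-1)) for i in 1:length(signal)-m+1]
end

# Local maxima of `score` above `threshold`, at least `min_dist` samples apart
# (when two candidates are closer than that, the stronger one is kept).
function pick_peaks(score, threshold, min_dist)
    peaks = Int[]
    for i in 2:length(score)-1
        is_peak = score[i] > threshold && score[i] >= score[i-1] && score[i] >= score[i+1]
        is_peak || continue
        if isempty(peaks) || i - peaks[end] >= min_dist
            push!(peaks, i)
        elseif score[i] > score[peaks[end]]
            peaks[end] = i
        end
    end
    return peaks
end

# Synthetic check: a Gaussian-shaped pulse (20 samples) added at three known start samples.
template = [exp(-((k - 10) / 3)^2) for k in 1:20]
signal = zeros(200)
starts = [5, 60, 118]
for s in starts
    signal[s:s+19] .+= template
end

score = matched_filter(signal, template)
beat_starts = pick_peaks(score, 0.5 * maximum(score), 30)
println(beat_starts)
```

With real data, the threshold and min_dist would need tuning (min_dist could come from the shortest plausible beat interval, e.g. 0.4 s = 20 samples at 50 Hz), and this gives point estimates of the beat starts rather than a posterior.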

I think you are correct that this is where it goes wrong.

I’m not sure I completely understand your proposal.
If I independently sample “time since last beat” for each point in my signal, I’ll lose the information that these observations are consecutive and always have 20 ms between them.

You’re right. Maybe “time since last beat” could follow the continuous-state equivalent of a discrete Markov chain with transition structure

1 -> 2 -> 3 -> ... -> T -> 1

Except that T has to be able to vary, so transitions from every state to the reset state could be allowed. Or something to that effect; I’m not sure of the details.
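
One concrete version of this idea, sketched under assumptions and not tested on the real data: measure the phase in whole samples (20 ms at 50 Hz), so state k means “k - 1 samples since the last beat started”. From each state the chain either advances by one sample or resets to phase 0, with a hypothetical reset probability h (zero before a minimum beat length kmin, and forced at the last state K). The forward algorithm sums over every possible phase sequence, so the beat times are marginalized out instead of sampled, and the log-likelihood is a smooth function of the template (the beat spline evaluated at phases 0, Δt, 2Δt and so on), σ and h. forward_loglik and its helpers are hypothetical names:

```julia
# Log-density of Normal(μ, σ) at x (σ is a standard deviation).
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Numerically stable log(exp(a) + exp(b)).
function logaddexp(a, b)
    a == -Inf && return b
    b == -Inf && return a
    m = max(a, b)
    return m + log(exp(a - m) + exp(b - m))
end

"""
    forward_loglik(y, template, σ, kmin, h)

Hypothetical sketch: log-likelihood of `y` with the beat phase marginalized out.
State k (1-based) means "k - 1 samples since the last beat started" and
`template[k]` is the expected signal at that phase. The first sample is assumed
to be at phase 0 (a beat starts at t = 0).
"""
function forward_loglik(y, template, σ, kmin, h)
    K = length(template)
    T = promote_type(eltype(y), eltype(template), typeof(σ), typeof(h))
    logα = fill(T(-Inf), K)
    logα[1] = normal_logpdf(y[1], template[1], σ)
    for t in 2:length(y)
        next_logα = fill(T(-Inf), K)
        for k in 1:K
            logα[k] == -Inf && continue
            # reset probability: forced at the last phase, zero before kmin, h otherwise
            p_reset = k == K ? one(T) : (k - 1 < kmin ? zero(T) : T(h))
            if p_reset > 0   # a new beat starts at sample t
                next_logα[1] = logaddexp(next_logα[1], logα[k] + log(p_reset))
            end
            if p_reset < 1   # the current beat continues
                next_logα[k + 1] = logaddexp(next_logα[k + 1], logα[k] + log1p(-p_reset))
            end
        end
        for k in 1:K
            next_logα[k] += normal_logpdf(y[t], template[k], σ)
        end
        logα = next_logα
    end
    return foldl(logaddexp, logα)
end
```

Inside a Turing model, the result could be added with @addlogprob!, with template computed from s_beat on the phase grid. Each evaluation costs on the order of length(y) × K operations, which is modest for 10 s at 50 Hz (500 samples). The constant reset probability is the simplest choice; it could be replaced by one derived from the LogNormal interval distribution.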


I have never worked with Markov chain models, so I would not know where to start.

I just updated the OP with a link to a standalone script for just the problematic model, in case someone wants to have a go at making it work (heartbeat_question_code.jl · GitHub).

That would probably work, but I really want to do it in one model, to get all that sweet Bayesian inference for the beat position as well ☺️

HMC won’t work well for finding the right period between beats (when there are multiple distinct periods that could work), but I bet Pigeons.jl would be able to sample it successfully.


Wow, that was impressive. It still needs a lot of fine-tuning, but with Pigeons.jl it actually samples the model.

I tightened my prior for the heart rate: mean_beat_interval ~ LogNormal(log(1.1), 0.2), and after that, it ran.
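
To see what this tightening does, the two priors on mean_beat_interval can be translated into heart rates. A small sketch with Distributions.jl; bpm_range is a hypothetical helper returning the central 95% range in beats per minute:

```julia
using Distributions

# Central 95% range of a prior on the beat interval (s), converted to beats per minute.
function bpm_range(d)
    lo, hi = quantile(d, 0.025), quantile(d, 0.975)
    return (60 / hi, 60 / lo)   # a longer interval means a lower heart rate
end

original = LogNormal(log(1), 0.4)     # prior in the original model
tightened = LogNormal(log(1.1), 0.2)  # tightened prior used with Pigeons.jl

println("original prior:  ", round.(bpm_range(original); digits = 1), " bpm")
println("tightened prior: ", round.(bpm_range(tightened); digits = 1), " bpm")
```

The original prior spans roughly 27 to 131 bpm, wide enough to contain both half and double the tightened prior’s median of about 55 bpm (1.1 s), so plausibly several distinct beat periods could fit the data; the tightened prior spans roughly 37 to 81 bpm.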

using Pigeons

pt = pigeons(target = TuringLogPotential(full_spline_model),
        record = [traces],
        n_rounds = 7,
        multithreaded = true)

chains_full_spline_model = Chains(pt)

full_spline_model_post = full_spline_regression(X10s, Vector{Union{Missing, Float64}}(undef, length(X10s)))

samples_full_spline_model_post = generated_quantities(full_spline_model_post, chains_full_spline_model) 

samples_full_spline_model_post_mu = gq2matrix(samples_full_spline_model_post, :mu)

plot_post_pred!(samples_full_spline_model_post_mu, X10s)
plot!(X10s, Y10s, lw = 1, color = :black)

[Figure 7: posterior draws of mu over the 10 s signal (Pigeons.jl), with the data in black]

Prediction of the first second using the entire model:

full_spline_model_post_1s = full_spline_regression(X1s_pred, Vector{Union{Missing, Float64}}(undef, length(X1s_pred)))

samples_full_spline_model_post_1s = generated_quantities(full_spline_model_post_1s, chains_full_spline_model) 

samples_full_spline_model_post_mu_1s = gq2matrix(samples_full_spline_model_post_1s, :mu)

plot_post_pred!(samples_full_spline_model_post_mu_1s, X1s_pred)
plot!(X1s, Y1s, lw = 1, color = :black)

[Figure 8: posterior draws of mu for the first second from the full model, with the data in black]


A few immediate observations:

n_beats ~ Poisson((1/mean_beat_interval) * obs_len_s) # Not currently used

will cause unnecessary “annoyance” for the sampler, so I’d just comment it out if it’s not used:)

y[i] ~ Normal(mu[i], σ²)

Normal is parameterized by mean and standard deviation, not variance; is the above intended?
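
A quick check of that parameterization with Distributions.jl (the value 4.0 for σ² is just an example): the second argument of Normal is what std returns, so a parameter meant as a variance needs a square root.

```julia
using Distributions

# Normal(μ, σ) takes a standard deviation as its second argument.
d = Normal(0.0, 2.0)
println(std(d))   # 2.0
println(var(d))   # 4.0

# If a parameter is meant to be a variance (example value), take its square root.
σ² = 4.0
d_from_var = Normal(0.0, sqrt(σ²))
println(std(d_from_var) == std(d))   # true
```

In the models above, that would mean writing y[i] ~ Normal(mu[i], sqrt(σ²)), or renaming σ² to σ and treating its prior as a prior on the standard deviation.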

For the inference issue itself, I’d recommend a few things:

First make sure that you can recover simulated data (i.e. in the case where you have perfect model specification). You can do this as follows:

# Prior check.
params = rand(full_spline_model)
# Condition on the `params`.
generating_model = full_spline_regression(X10s) | params
# Simulate data.
y_fake = generating_model().pred
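# Infer! Condition only on the simulated data, so that the parameters are left free to be recovered.
chains_full_spline_prior_model = sample(full_spline_regression(X10s) | (y = y_fake,), NUTS(), 100)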

The second thing is to just fix the variables you believe to be causing issues to the “true” values and see if you can at least infer the rest of the model:

# Fix the variables related to the beat intervals as we're worried they might be causing issues.
model_simpler = Turing.fix(
    full_spline_regression(X10s),
    mean_beat_interval = 1.0, rel_var_beat_interval = 0.05, beat_intervals = fill(1.0, 12)  # twelve intervals of 1 s each
)

Then you run inference on this to check whether you can recover the parameters. This should at least let you identify what exactly is causing the inference issue. It might indeed be the case that you’ll need something like Pigeons.jl for the full model, but going through these steps first is probably a good idea. EDIT: Oh nvm 😅

Note that above, I used the following slightly modified model:

@model function full_spline_regression(x)  # `y` has been removed as an argument
    obs_len_s = last(x) - first(x)   	
    # Set prior for beat interval
    mean_beat_interval ~ LogNormal(log(1), 0.4)
    rel_var_beat_interval ~ Exponential(0.1)
    # n_beats ~ Poisson((1/mean_beat_interval) * obs_len_s) # Not currently used

	# Set variance prior.
    σ² ~ truncated(Normal(0, 1); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set spline prior
    w_beat ~ MvNormal(zeros(length(basis_beat)), 1)
    w_trend ~ MvNormal(zeros(length(basis_trend)), 1)

    s_beat = Spline(basis_beat, w_beat)
    s_trend = Spline(basis_trend, w_trend)

    # Beat vector
    # Instead of drawing 12 intervals, this should be n_beats (possibly drawn from a Poisson), but there may be a better way
    beat_intervals ~ filldist(LogNormal(log(mean_beat_interval),
                                        rel_var_beat_interval), 12) 
    beat_pos = [0; cumsum(beat_intervals)]
    time_since_beat = get_time_since_event(x, beat_pos)

    # Calculate all the mu terms.
    mu = intercept .+ s_beat.(time_since_beat) .+ s_trend.(x)

    # Vectorized version: a bit faster + nicer to work with when using `condition`, etc.
    y ~ MvNormal(mu, σ²)
    # for i in eachindex(y)
    #     # Likelihood
    #     y[i] ~ Normal(mu[i], σ²)
    # end

    return (mu = mu, pred = y) # sampled with generated_quantities
end

Just making y a random variable makes it so much easier to use the condition / | syntax to sample from the model, instead of messing around with missings :)


Awesome:)

This is so much nicer! Is it documented anywhere? I looked through a lot of docs and forum posts just to get to the missing approach.

Just brainstorming, in the “I know enough to be dangerous” phase: this research seems to have similarities with other research I have come across, which was applied to ECG data, if I recall correctly.

It is, but unfortunately it’s a bit spread out at the moment 😕

EDIT: Note that it’s not always a perfect replacement for passing the observation in as an argument (see the docstring of condition).
