Help speeding up and reducing numerical errors for simple epidemic model with Turing.jl

I’m trying to fit a simple epidemic model that starts at some unknown time and then the prevalence of infection progresses according to a logistic curve. The data provided are Binomial with a known sample size. I generated some toy data and am trying to get Turing.jl to do inference but it gives me a ton of warnings and takes a long time.

┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47

I’d also ideally like to do this repeatedly many times so I need the code to run very fast.

Here’s the simple model.

# Random for reproducibility, StatsFuns for logistic/logit, Turing for inference.
using Random
using StatsFuns
using Turing
# True parameters: β = 0.015008, p0 = 0.01, start_time = 100
# Data
n = 200 # known Binomial sample size at each test time

# Daily positive test counts out of n = 200, simulated from the true parameters above.
positive_counts = [1.0, 1.0, 1.0, 2.0, 4.0, 4.0, 3.0, 2.0, 2.0, 2.0, 0.0, 4.0, 3.0, 2.0, 2.0, 
2.0, 4.0, 2.0, 1.0, 3.0, 3.0, 1.0, 3.0, 1.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 0.0, 2.0, 8.0, 2.0, 5.0, 
3.0, 0.0, 2.0, 2.0, 4.0, 1.0, 3.0, 2.0, 5.0, 2.0, 5.0, 3.0, 3.0, 3.0, 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 
2.0, 1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 2.0, 2.0, 1.0, 2.0, 1.0, 4.0, 2.0, 0.0, 1.0, 2.0, 1.0, 3.0, 3.0, 
2.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 3.0, 0.0, 2.0, 1.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 3.0, 
3.0, 2.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 3.0, 1.0, 4.0, 1.0, 1.0, 3.0, 1.0, 2.0, 4.0, 2.0, 1.0, 1.0, 
4.0, 1.0, 0.0, 2.0, 1.0, 5.0, 2.0, 3.0, 3.0, 4.0, 0.0, 6.0, 4.0, 3.0, 4.0, 3.0, 1.0, 4.0, 5.0, 5.0, 
1.0, 2.0, 5.0, 4.0, 4.0, 4.0, 7.0, 3.0, 2.0, 2.0, 7.0, 5.0, 3.0, 2.0, 3.0, 4.0, 4.0, 3.0, 6.0, 1.0, 
6.0, 4.0, 1.0, 4.0, 8.0, 2.0, 2.0, 3.0, 3.0, 2.0, 4.0, 3.0, 8.0, 2.0, 7.0, 7.0, 5.0, 8.0, 9.0, 6.0, 
8.0, 8.0, 4.0, 5.0, 8.0, 2.0, 8.0, 4.0, 6.0, 6.0, 13.0, 9.0, 12.0, 8.0, 6.0, 8.0, 4.0, 5.0, 9.0, 4.0, 
9.0, 8.0, 4.0, 7.0, 10.0, 11.0, 12.0, 8.0, 11.0, 9.0, 12.0, 12.0, 13.0, 13.0, 10.0, 10.0, 11.0, 15.0, 
12.0, 11.0, 4.0, 7.0, 7.0, 9.0, 13.0, 11.0, 13.0, 16.0, 11.0, 9.0, 16.0, 12.0, 11.0, 15.0, 11.0, 9.0, 
16.0, 11.0, 13.0, 16.0, 12.0, 17.0, 11.0, 20.0, 16.0, 16.0, 19.0, 15.0, 12.0, 13.0, 12.0, 14.0, 15.0, 
17.0, 22.0, 20.0, 16.0, 20.0, 17.0, 18.0, 17.0, 19.0, 16.0, 19.0, 17.0, 23.0, 24.0, 22.0, 19.0, 19.0, 
20.0, 19.0, 23.0, 28.0, 20.0, 20.0, 24.0, 26.0, 25.0, 26.0, 21.0, 34.0, 29.0, 28.0, 23.0, 29.0, 28.0, 
27.0, 36.0, 34.0, 29.0, 22.0, 17.0, 29.0, 28.0, 23.0, 39.0, 20.0, 28.0, 31.0, 23.0, 37.0, 31.0, 39.0, 49.0]

# One observation per day over 300 days (t = 0, 1, …, 299).
test_times = collect(0.0:1.0:299.0)

# Model
@model logistic_epidemic(n, positive_counts, test_times) = begin
    β ~ Gamma(0.1, 10) # transmission rate
    p0 ~ Beta(1, 1) # initial prevalence
    start_time ~ Uniform(0, test_times[end]) # epidemic start time
    for (i, t) in enumerate(test_times)
        # NOTE(review): since log(1 / p0 - 1) == -logit(p0), this is
        # logistic(-β * (t - start_time) - logit(p0)), which equals 1 - p0 at
        # t == start_time and *decreases* with t — the sign looks inverted for
        # a growing epidemic (the thread later confirms a sign error when it
        # switches to BinomialLogit(n, -p)).
        # Also, logistic() can saturate to exactly 0.0 or 1.0 in Float64, making
        # the Binomial log-density -Inf — the "proposal rejected due to
        # numerical error(s)" warnings.
        p = logistic(-β * (t - start_time) + log(1 / p0 - 1))
        positive_counts[i] ~ Binomial(n, p)
    end
end

# Inference
chain = sample(logistic_epidemic(n, positive_counts, test_times), NUTS(1000, 0.95), 4000)

Try using the following instead:

@model logistic_epidemic(n, positive_counts, test_times) = begin
    β ~ Gamma(0.1, 10) # transmission rate
    p0 ~ Beta(1, 1) # initial prevalence
    # z == log(1 / p0 - 1) == -logit(p0); hoisted out of the loop so it is
    # computed once per model evaluation.
    z = log(1 / p0 - 1)
    start_time ~ Uniform(0, test_times[end]) # epidemic start time
    for (i, t) in enumerate(test_times)
        # Work on the logit scale: BinomialLogit evaluates the Binomial
        # log-density directly from the logit, so logistic() never saturates
        # to exactly 0 or 1.
        # NOTE(review): as posted the sign is still flipped; the thread later
        # corrects it by observing positive_counts[i] ~ BinomialLogit(n, -p).
        p = -β * (t - start_time) + z
        positive_counts[i] ~ BinomialLogit(n, p)
    end
end

Thanks! Unfortunately this still causes a lot of numerical warnings.

The priors you use are pretty uninformative — is this intentional?

Yes, I was intending to use uninformative priors.

@Kai_Xu do you have a good intuition why this model might be problematic for HMC?

Actually, I found a problem. The BinomialLogit function is parameterized such that it should be positive_counts[i] ~ BinomialLogit(n, -p). The parameters were restricted to the non-negative domain so it couldn’t get a good fit previously. It works now. Do you have any other tips for making it faster?

I’ve also tried using ADVI but I get NaN when I use rand.

# ADVI(samples_per_step, max_iters): 10 Monte Carlo samples per gradient
# estimate, 1000 optimization steps.
advi = ADVI(10, 1000)
q = vi(logistic_epidemic(n, positive_counts, test_times), advi)
julia> q
Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Inverse{Bijectors.Logit{0,Float64},0},Inverse{Bijectors.TruncatedBijector{0,Float64,Float64},0}},3},Multivariate}(
dist: DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[NaN, NaN, NaN], σ=[NaN, NaN, NaN])
transform: Stacked{Tuple{Bijectors.Exp{0},Inverse{Bijectors.Logit{0,Float64},0},Inverse{Bijectors.TruncatedBijector{0,Float64,Float64},0}},3}((Bijectors.Exp{0}(), Inverse{Bijectors.Logit{0,Float64},0}(Bijectors.Logit{0,Float64}(0.0, 1.0)), Inverse{Bijectors.TruncatedBijector{0,Float64,Float64},0}(Bijectors.TruncatedBijector{0,Float64,Float64}(0.0, 299.0))), (1:1, 2:2, 3:3))
)


julia> rand(q)
3-element Array{Float64,1}:
 NaN
 NaN
 NaN

Also, just a note that I modified the model slightly to include a max.

@model logistic_epidemic(n, positive_counts, test_times) = begin
    β ~ Gamma(0.1, 10) # transmission rate
    p0 ~ Beta(1, 1) # initial prevalence
    start_time ~ Uniform(0, test_times[end]) # epidemic start time
    # z == log(1 / p0 - 1) == -logit(p0)
    z = log(1 / p0 - 1)
    for (i, t) in enumerate(test_times)
        # NOTE(review): this mixes scales — p0 is a probability in (0, 1) while
        # -β * (t - start_time) + z is a logit, so max.() compares two
        # incommensurable quantities (the author confirms below that this model
        # is not correct as written). The broadcast dot on max.() is also
        # redundant for scalar arguments.
        p = max.(p0, -β * (t - start_time) + z)
        positive_counts[i] ~ BinomialLogit(n, -p) # note we are using -p here
    end
end

I’ll have a look at it in the next days.

The last model I put is not correct. Here’s what I intended but this breaks with an inexact error.

using Random
using StatsFuns
using Plots
using Turing

# True parameters used to simulate the toy data below.
const β_true = 0.015008 # transmission rate
const p0_true = 0.01 # initial prevalence
const Γ_true = 99.0 # epidemic start time

# Data
const n = 200 # known Binomial sample size at each test time

# Daily positive test counts out of n.
const W = [1.0, 1.0, 1.0, 2.0, 4.0, 4.0, 3.0, 2.0, 2.0, 2.0, 0.0, 4.0, 3.0, 2.0, 2.0, 
2.0, 4.0, 2.0, 1.0, 3.0, 3.0, 1.0, 3.0, 1.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 0.0, 2.0, 8.0, 2.0, 5.0, 
3.0, 0.0, 2.0, 2.0, 4.0, 1.0, 3.0, 2.0, 5.0, 2.0, 5.0, 3.0, 3.0, 3.0, 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 
2.0, 1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 2.0, 2.0, 1.0, 2.0, 1.0, 4.0, 2.0, 0.0, 1.0, 2.0, 1.0, 3.0, 3.0, 
2.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 5.0, 3.0, 0.0, 2.0, 1.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 3.0, 
3.0, 2.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 3.0, 1.0, 4.0, 1.0, 1.0, 3.0, 1.0, 2.0, 4.0, 2.0, 1.0, 1.0, 
4.0, 1.0, 0.0, 2.0, 1.0, 5.0, 2.0, 3.0, 3.0, 4.0, 0.0, 6.0, 4.0, 3.0, 4.0, 3.0, 1.0, 4.0, 5.0, 5.0, 
1.0, 2.0, 5.0, 4.0, 4.0, 4.0, 7.0, 3.0, 2.0, 2.0, 7.0, 5.0, 3.0, 2.0, 3.0, 4.0, 4.0, 3.0, 6.0, 1.0, 
6.0, 4.0, 1.0, 4.0, 8.0, 2.0, 2.0, 3.0, 3.0, 2.0, 4.0, 3.0, 8.0, 2.0, 7.0, 7.0, 5.0, 8.0, 9.0, 6.0, 
8.0, 8.0, 4.0, 5.0, 8.0, 2.0, 8.0, 4.0, 6.0, 6.0, 13.0, 9.0, 12.0, 8.0, 6.0, 8.0, 4.0, 5.0, 9.0, 4.0, 
9.0, 8.0, 4.0, 7.0, 10.0, 11.0, 12.0, 8.0, 11.0, 9.0, 12.0, 12.0, 13.0, 13.0, 10.0, 10.0, 11.0, 15.0, 
12.0, 11.0, 4.0, 7.0, 7.0, 9.0, 13.0, 11.0, 13.0, 16.0, 11.0, 9.0, 16.0, 12.0, 11.0, 15.0, 11.0, 9.0, 
16.0, 11.0, 13.0, 16.0, 12.0, 17.0, 11.0, 20.0, 16.0, 16.0, 19.0, 15.0, 12.0, 13.0, 12.0, 14.0, 15.0, 
17.0, 22.0, 20.0, 16.0, 20.0, 17.0, 18.0, 17.0, 19.0, 16.0, 19.0, 17.0, 23.0, 24.0, 22.0, 19.0, 19.0, 
20.0, 19.0, 23.0, 28.0, 20.0, 20.0, 24.0, 26.0, 25.0, 26.0, 21.0, 34.0, 29.0, 28.0, 23.0, 29.0, 28.0, 
27.0, 36.0, 34.0, 29.0, 22.0, 17.0, 29.0, 28.0, 23.0, 39.0, 20.0, 28.0, 31.0, 23.0, 37.0, 31.0, 39.0, 49.0]

# One observation per day over 300 days (t = 0, 1, …, 299).
const t = collect(0.0:1.0:299.0)

const n_samples = length(W) # number of observations

# Total Binomial log-likelihood of positive counts `W` (each out of `n` tests)
# at times `t` under a clamped-logistic prevalence curve.
#
# Arguments:
#   β  — transmission rate (logistic growth rate)
#   p0 — initial prevalence; anchors the curve (pe == p0 at t == Γ) and floors it
#   Γ  — epidemic start time
#   n  — Binomial sample size (scalar)
#   W  — vector of positive counts
#   t  — vector of test times
#
# Returns the log-likelihood up to the binomial-coefficient terms, which are
# constant in (β, p0, Γ) and can therefore be dropped for inference.
function epidemic_loglikeli(β, p0, Γ, n, W, t)
    # Prevalence: logistic curve on the logit scale, floored at p0 for t < Γ.
    pe = max.(p0, logistic.(β .* t .- β * Γ .+ logit(p0)))
    # log1p(-pe) is more accurate than log(1 - pe) when pe is small, which is
    # exactly the early-epidemic regime where most observations sit.
    return sum(W .* log.(pe) .+ (n .- W) .* log1p.(-pe))
end

# Model
#
# Fix for the InexactError: Geometric(0.01) is a *discrete* distribution, but
# NUTS is gradient-based and proposes non-integer values for Γ, which cannot be
# mapped back to the Geometric's integer support (and the discrete prior has no
# gradient). Exponential(100) is the continuous analogue — same mean
# (1 / 0.01 = 100) and same Γ ≥ 0 support — so NUTS can sample it.
@model logistic_epidemic(n, W, t) = begin
    β ~ Gamma(0.1, 10) # transmission rate
    p0 ~ Beta(1, 1) # initial prevalence
    Γ ~ Exponential(100) # epidemic start time (continuous, HMC-compatible)
    z = logit(p0)
    βΓ = β * Γ # hoist the loop-invariant product
    for (i, ti) in enumerate(t)
        # Logit-scale predictor floored at z == logit(p0): logit is monotone,
        # so max(z, β * ti - βΓ + z) is equivalent to flooring the prevalence
        # at p0 on the probability scale, while BinomialLogit avoids logistic()
        # saturating to exactly 0 or 1.
        W[i] ~ BinomialLogit(n, max(z, β * ti - βΓ + z))
    end
end

# Inference
chain = sample(logistic_epidemic(n, W, t), NUTS(1000, 0.95), 2000)