Help with slow Turing.jl inference for a large number of observations

Hi all,

I am new to Turing.jl and Bayesian inference in general. I have specified the following model:

using Turing

@model function switch_model(r, m, n)

    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    # Latent switch indicators (discrete), one per observation.
    z = tzeros(Int, n)

    for i in 1:n
        z[i] ~ Bernoulli(κm)
        if z[i] == 1
            Pm = f(r[i], hm, am)
            m[i] ~ Bernoulli(Pm)
        else
            m[i] ~ Bernoulli(Δm)
        end
    end

end;

The vector r contains approximately 250k observations, with integer values from 1 to 8, i.e. r = rand(1:8, 250000). The function f is a kind of sigmoid that takes values between zero and one, interpreted as a probability Pm. The observations m are a vector of booleans. I am not normalising the values in r, and would prefer not to unless necessary, so that I can use the inferred parameters as they come out of the inference process.
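For illustration, f has roughly this shape (an illustrative stand-in only; my actual f differs in detail, but it is smooth and maps into (0, 1)):

# Illustrative stand-in for f: logistic in r with threshold h and slope a.
f(r, h, a) = 1 / (1 + exp(-a * (r - h)))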

I am aware that this is a large number of observations, but I would expect Turing to be quite quick given that it is a simple model with a small number of parameters. In general I am unsure which sampler to use - I have tried NUTS using

model = switch_model(r, m, n);
chn = sample(model, NUTS(), 1000)

This is very slow - in fact I think it breaks my kernel most of the time. I thought this was likely because of the intermediate z[i] in the logic, which is being tracked by the sampler. I therefore tried to rewrite the model:

@model function switch_experiment_mm(r, m, n)

    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    for i in 1:n
        Pm = f(r[i], hm, am)
        # Marginalise z[i] by mixing the two Bernoulli components.
        m[i] ~ MixtureModel(Bernoulli[Bernoulli(Pm), Bernoulli(Δm)],
                            [κm, 1 - κm])
    end

end;

I did this because I think my original logic is equivalent to a mixture model. However, this also doesn’t work and seems to break the kernel again.

Am I doing something fundamentally wrong, or is this simply a case of large number of observations = slow?

I remember having to use Turing.@addlogprob! in the past to speed up large models. It basically lets you set the log-probability directly instead of defining hundreds or thousands of individual variables. Search this forum for examples of usage.
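In your case it could look something like this (an untested sketch; note that a two-component Bernoulli mixture is itself a Bernoulli, so the marginalised likelihood has a closed form):

@model function switch_model_fast(r, m)

    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    # Marginalised P(m[i] = 1): κm * f(r[i], hm, am) + (1 - κm) * Δm.
    p = κm .* f.(r, hm, am) .+ (1 - κm) .* Δm

    # Accumulate the whole Bernoulli log-likelihood in one statement
    # instead of creating 250k observe statements in a loop.
    Turing.@addlogprob! sum(@. m * log(p) + (1 - m) * log1p(-p))

end;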

If I am not mistaken, this is true for NUTS: being a gradient-based sampler, it cannot handle the discrete z[i] unless you marginalize it out explicitly. I think MixtureModel should help, but there might be other problems. Have you tried a parameter recovery with a small set of data generated from your model? That might help you identify whether the problem is in your implementation or a scaling issue.
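For example, something along these lines (a sketch reusing your f and the marginalised model above; the "true" values are placeholders):

using Random

Random.seed!(1234)

# Placeholder "true" values for the recovery check.
hm_true, am_true, κm_true, Δm_true = 4.0, 2.0, 0.7, 0.2

n_small = 2_000
r_small = rand(1:8, n_small)
z_small = rand(Bernoulli(κm_true), n_small)
m_small = [z_small[i] ? rand(Bernoulli(f(r_small[i], hm_true, am_true))) :
                        rand(Bernoulli(Δm_true)) for i in 1:n_small]

chn_small = sample(switch_model_fast(r_small, m_small), NUTS(), 1000)

# Check that the posterior means land near the placeholder values.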