Help with slow Turing.jl inference speeds for large number of observations

Hi all,

I am new to Turing.jl and Bayesian inference in general. I have specified the following model:

@model function switch_model(r, m, n)

    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    z = tzeros(Int, n)

    for i in 1:n
        z[i] ~ Bernoulli(κm)
        if z[i] == 1
            Pm = f(r[i],hm,am)
            m[i] ~ Bernoulli(Pm)
        else
            m[i] ~ Bernoulli(Δm)
        end
    end

end;

The length of r is approximately 250k observations; it contains integer values from 1 to 8, i.e. r = rand(1:8, 250000). The function f is a kind of sigmoid that takes values between zero and one (interpreted as a probability Pm). The observations m are a vector of booleans. I am not normalising the values in r and would prefer not to unless necessary, so that I can use the inferred parameters as they come out of the inference process.
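For reference, f has the signature f(r, hm, am) and returns a value in (0, 1); a stand-in of roughly the right shape (not my actual function) would be:

```julia
# Hypothetical stand-in for f: a logistic curve in r, with
# half-maximum point hm and steepness am. My real f differs in
# detail but has the same signature and (0, 1) range.
f(r, hm, am) = 1 / (1 + exp(-am * (r - hm)))
```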

I am aware that this is a large number of observations, but I would expect Turing to be quite quick given that it is a simple model with a low number of parameters. In general I am unsure which sampler to use - I have tried NUTS using

model = switch_model(r, m, n);
chn = sample(model, NUTS(), 1000)

This is very slow - in fact, I think it crashes my kernel most of the time. I suspected this was because of the intermediate z[i] in the logic, which is being tracked by the sampler. I therefore tried to rewrite the model:

@model function switch_experiment_mm(r, m, n)

    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    for i in 1:n
        Pm = f(r[i],hm,am)
        m[i] ~ MixtureModel(Bernoulli[
                            Bernoulli(Pm),
                            Bernoulli(Δm)], [κm, 1-κm])
    end

end;

I think my logic is equivalent to a mixture model. However, this also doesn’t work and again seems to crash the kernel.

Am I doing something fundamentally wrong, or is this just a case of a large number of observations = slow?

I remember having to use Turing.@addlogprob! in the past to speed up large models. It allows you to add to the log-probability directly instead of defining hundreds or thousands of tilde statements. Search this forum for usage examples.
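As a rough sketch of what I mean (my own, untested): if you marginalize z out analytically, each m[i] is Bernoulli with probability κm*Pm + (1-κm)*Δm, so you can accumulate the whole log-likelihood in plain Julia and hand it to Turing in a single call:

```julia
using Turing

# Sketch only: z[i] is summed out, so the model has only the four
# continuous parameters and one @addlogprob! call for all 250k points.
@model function switch_marginal(r, m, n)
    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    lp = zero(κm)   # zero of the parameter eltype, so AD types propagate
    for i in 1:n
        p = κm * f(r[i], hm, am) + (1 - κm) * Δm
        lp += m[i] ? log(p) : log1p(-p)
    end
    Turing.@addlogprob! lp
end
```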

If I am not mistaken, this is true for NUTS. I don’t think it handles mixtures well unless you marginalize explicitly. I think MixtureModel should help, but there might be other problems. Have you tried a parameter-recovery check on a small dataset generated from your model? That would help you identify whether the problem is in your implementation or just a scaling issue.
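Something along these lines (hypothetical truth values, and a logistic stand-in for f since it isn’t shown in the post):

```julia
using Random

Random.seed!(1)

# Stand-in for the unshown f; replace with the real one.
f(r, hm, am) = 1 / (1 + exp(-am * (r - hm)))

# Ground-truth parameters to recover (arbitrary choices).
hm_true, am_true, κm_true, Δm_true = 4.0, 2.0, 0.7, 0.1

n = 2_000                       # small, so sampling is fast
r = rand(1:8, n)
z = rand(n) .< κm_true          # latent switch per observation
m = [z[i] ? rand() < f(r[i], hm_true, am_true) : rand() < Δm_true
     for i in 1:n]

# Fit this (r, m, n) with your model and check that the posterior
# means land near the *_true values before scaling up to 250k points.
```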

It’s weird that your code even worked. NUTS does not work with discrete variables.

Your rewritten model is more sensible, but it will indeed have speed issues. It’s better to vectorize everything: Turing is essentially a thin wrapper around Julia, so anything that speeds up Julia also speeds up Turing. Also try different AD backends - that makes a big difference.
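For example, something like this (a sketch, not tested: the mixture is marginalized so everything is continuous, and one vectorized observe statement covers all of m):

```julia
using Turing

@model function switch_vectorized(r, m)
    hm ~ Uniform(1, 10)
    am ~ Uniform(0, 10)
    κm ~ Uniform(0, 1)
    Δm ~ Uniform(0, 1)

    # Marginal success probability after summing the latent z out.
    p = κm .* f.(r, hm, am) .+ (1 - κm) .* Δm
    m ~ product_distribution(Bernoulli.(p))
end

chn = sample(switch_vectorized(r, m), NUTS(), 1000)
```

Depending on your Turing version, you can switch the AD backend with e.g. `Turing.setadbackend(:reversediff)`; for a model with this many observations and only four parameters, reverse-mode AD usually scales much better than the default forward mode.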