Efficient vectorization in Turing.jl

I have a fairly simple hierarchical model with two population parameters, a and b. b also acts as an individual-level parameter, in the sense that the individual-level b_i are sampled from a distribution parametrized by b. a and b_i govern the shape of an individual-level function f(t; a, b_i), from which the individual-level data are sampled using a Bernoulli at each timepoint t.

What’s the best way to vectorize this? Since the b_i are i.i.d., filldist should do the job for them. However, the data for each cell (individual) are not identically distributed, since f(t; a, b_i) differs across timepoints and individuals, so it seems like I need arraydist there.

So the model boils down to:

a_population ~ P(.)
b_population ~ Q(.)
b_individual ~ filldist(Q(.;b), num_individuals)
individual_responses(a, b) ~ arraydist(map(t -> Bernoulli(f(t; a, b)), 1:num_timepoints))
population_responses ~ arraydist(map(i -> individual_responses(a, b_individual[i]), 1:num_cells))

where I’m using P(.) and Q(.) to denote arbitrary continuous-valued distributions.
Is using map here efficient? I’m assuming it’s better than an array comprehension, but I’m not sure how well Turing deals with it. All the parameters are continuous, so I’m hoping to use NUTS with reverse-mode differentiation. The full model will have more parameters, but this is a minimal example.
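
To make the pseudocode concrete, here is a minimal runnable sketch of what I have in mind, assuming standard normal priors for P and Q and a logistic f (both just placeholders for the real model):

using Turing

# Placeholder individual-level curve; any smooth f(t; a, b_i) would do.
f(t, a, b_i) = 1 / (1 + exp(-(a + b_i * t)))

@model function hierarchical(responses, num_timepoints, num_cells)
    # Population-level parameters (P and Q taken to be standard normals here).
    a ~ Normal(0, 1)
    b ~ Normal(0, 1)

    # Individual-level parameters: i.i.d. given b, so filldist applies.
    b_individual ~ filldist(Normal(b, 1.0), num_cells)

    # Responses: non-identical Bernoullis, so arraydist over a
    # num_timepoints × num_cells array of distributions.
    responses ~ arraydist([Bernoulli(f(t, a, b_individual[i]))
                           for t in 1:num_timepoints, i in 1:num_cells])
end

# Hypothetical usage:
# data = rand(Bool, num_timepoints, num_cells)
# chain = sample(hierarchical(data, num_timepoints, num_cells), NUTS(), 1000)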

Tested using the latest releases of Turing and DynamicPPL:

using Random, Turing, DynamicPPL, Chairmarks
using LinearAlgebra  # provides the identity scaling I used in MvNormal(means, I)

@model function loop(means)
    x = Vector{Float64}(undef, length(means))
    for i in eachindex(means)
        x[i] ~ Normal(means[i], 1.0)
    end
end

@model function mvnormal(means)
    x ~ MvNormal(means, I)
end

@model function arraydist_map(means)
    x ~ arraydist(map(m -> Normal(m, 1.0), means))
end

@model function arraydist_comp(means)
    x ~ arraydist([Normal(means[i], 1.0) for i in eachindex(means)])
end

@model function arraydist_broadcast(means)
    x ~ arraydist(Normal.(means, 1.0))
end

@model function product_comp(means)
    x ~ product_distribution([Normal(means[i], 1.0) for i in eachindex(means)])
end

means = randn(Xoshiro(468), 10)

# Time a single evaluation of each model's log density.
for model_gen in [loop, mvnormal, arraydist_map, arraydist_comp, arraydist_broadcast, product_comp]
    model = model_gen(means)
    varinfo = VarInfo(model)
    t = @be DynamicPPL.evaluate!!($model, $varinfo)
    println("$(model.f) => $(median(t).time)")
end

yields

loop => 1.2680869565217391e-6
mvnormal => 2.2128030303030305e-7
arraydist_map => 2.9042000000000003e-7
arraydist_comp => 2.8105660377358494e-7
arraydist_broadcast => 2.9542000000000003e-7
product_comp => 3.653846153846154e-7

Versions:

(ppl) pkg> st
Status `~/ppl/Project.toml`
  [0ca39b1e] Chairmarks v1.3.1
  [366bfd00] DynamicPPL v0.36.12 `dppl`
  [fce5fe82] Turing v0.39.3 `lib`
  [9a3f8284] Random v1.11.0

In the special case where the components are normal distributions, MvNormal seems to be the fastest. But in general, the three arraydist models take essentially the same time (if you rerun the benchmark, sometimes one is faster than the others). So I would not worry too much about the exact way you construct the array of distributions :slight_smile:

As for AD, I don’t expect any of the above formulations to cause problems with ReverseDiff, Mooncake, or Enzyme. If they do, feel free to open an issue, and we can look into it!
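
If you want to pin the AD backend down explicitly, a minimal sketch (reusing arraydist_map and means from the benchmark above, and assuming ReverseDiff is installed) would be:

using Turing
import ReverseDiff

# Select reverse-mode AD via the sampler's adtype keyword;
# compile=true tapes the log density once, which usually helps.
model = arraydist_map(means)
chain = sample(model, NUTS(; adtype=AutoReverseDiff(; compile=true)), 1_000)

Mooncake and Enzyme can be selected the same way through the corresponding AutoMooncake and AutoEnzyme types from ADTypes.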
