Vectorizing observations from a multivariate distribution in Turing

I have a model in which the data are drawn from multinomial distributions whose parameters (number of draws and probabilities) differ at each data point. In the minimal working example (MWE) below, only the probabilities vary:

ndata = 100  # number of observations
k = 5        # number of categories per observation
n = 10       # number of draws (trials) per observation
# Column i of `p` is the probability vector for observation i: draw random
# weights, then normalize each column so it sums to 1.
p = rand(k, ndata)
p ./= sum(p, dims=1)
# Simulate one count vector per observation. Iterate over the columns of `p`
# instead of a hard-coded `1:100`, so changing `ndata` keeps everything in sync.
counts = [rand(Multinomial(n, p[:, i])) for i in axes(p, 2)]

# Loop-based model: `counts` is a collection of count vectors, one per
# observation, each modeled as Multinomial(n, p[:, j]) with its own
# Dirichlet(ones(k)) prior on the probability vector.
@model function test_multinomial(counts, n, k)
    nobs = length(counts)
    # k × nobs matrix of probabilities, one column per observation.
    p ~ filldist(Dirichlet(ones(k)), nobs)
    for j in 1:nobs
        # Index into the model argument so each count vector is treated as observed data.
        counts[j] ~ Multinomial(n, p[:, j])
    end
end
# Draw 100 NUTS samples. Pass `k` rather than a literal 5 so the model's
# number of categories always matches the simulated data.
sample(test_multinomial(counts, n, k), NUTS(), 100)

My question: how can I vectorize the observation statement to speed up sampling with reverse-mode automatic differentiation (ReverseDiff)? I’ve tried this variant using `arraydist`:

# Re-simulate one count vector per column of `p` (not a hard-coded 1:100),
# then stack them into a k × ndata matrix with one observation per column,
# matching the per-column distributions in `test_multinomial_vec`.
counts = [rand(Multinomial(n, p[:, i])) for i in axes(p, 2)]
counts = reduce(hcat, counts)

# Vectorized model: `counts` is a k × ndata matrix whose columns are the
# observed count vectors; column i is modeled as Multinomial(n, p[:, i]).
@model function test_multinomial_vec(counts, n, k)
    # One probability vector per column. `length(counts)` would be k * ndata
    # for a matrix, creating many Dirichlet columns that no observation uses.
    p ~ filldist(Dirichlet(ones(k)), size(counts, 2))
    # NOTE(review): with the ReverseDiff backend, building this vector of
    # Multinomials has been reported to throw a MethodError inside
    # Distributions' `convert` for Multinomial{TrackedReal, TrackedArray}.
    # TODO confirm whether this is fixed upstream.
    counts ~ arraydist([Multinomial(n, p[:, i]) for i in axes(counts, 2)])
end
# Sample the vectorized model. `counts` is now a matrix, which
# `test_multinomial` (treating each `counts[i]` as one count vector) does not
# handle. Pass `k` instead of a literal 5 to stay consistent with the data.
sample(test_multinomial_vec(counts, n, k), NUTS(), 100)

but I get this `MethodError`:

Error message
MethodError: no method matching ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}})
Closest candidates are:
  ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}(::AbstractArray{V,N}, !Matched::AbstractArray{D,N}, !Matched::Array{ReverseDiff.AbstractInstruction,1}) where {V, D, N, VA, DA} at C:\Users\sam.urmy\.julia\packages\ReverseDiff\jFRo1\src\tracked.jl:74
convert(::Type{Multinomial{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}}, ::Multinomial{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}) at multinomial.jl:53
setindex!(::Array{Multinomial{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},1}, ::Multinomial{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::Int64) at array.jl:847
collect_to_with_first!(::Array{Multinomial{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},1}, ::Multinomial{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::Base.Generator{UnitRange{Int64},var"#459#461"{Int64,ReverseDiff.TrackedArray{Float64,Float64,2,Array{Float64,2},Array{Float64,2}}}}, ::Int64) at array.jl:709
collect(::Base.Generator{UnitRange{Int64},var"#459#461"{Int64,ReverseDiff.TrackedArray{Float64,Float64,2,Array{Float64,2},Array{Float64,2}}}}) at array.jl:691
#458 at stheno_turing_example.jl:338 [inlined]
(::var"#458#460")(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#458#460",(:counts, :n, :k),(),(),Tuple{Array{Int64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{NamedTuple{(:p,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{DistributionsAD.VectorOfMultivariate{Continuous,Dirichlet{Float64},FillArrays.Fill{Dirichlet{Float64},1,Tuple{Base.OneTo{Int64}}}},1},Array{DynamicPPL.VarName{:p,Tuple{}},1},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Array{Set{DynamicPPL.Selector},1}}}},ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}},Array{Base.RefValue{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}},1}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ReverseDiffAD{false},(),AdvancedHMC.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:p,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{DistributionsAD.VectorOfMultivariate{Continuous,Dirichlet{Float64},FillArrays.Fill{Dirichlet{Float64},1,Tuple{Base.OneTo{Int64}}}},1},Array{DynamicPPL.VarName{:p,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}, ::DynamicPPL.DefaultContext, ::Array{Int64,2}, ::Int64, ::Int64) at none:0
macro expansion at model.jl:0 [inlined]
_evaluate at model.jl:145 [inlined]
evaluate_threadsafe(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#458#460",(:counts, :n, :k),(),(),Tuple{Array{Int64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.VarInfo{NamedTuple{(:p,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{DistributionsAD.VectorOfMultivariate{Continuous,Dirichlet{Float64},FillArrays.Fill{Dirichlet{Float64},1,Tuple{Base.OneTo{Int64}}}},1},Array{DynamicPPL.VarName{:p,Tuple{}},1},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Array{Set{DynamicPPL.Selector},1}}}},ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ReverseDiffAD{false},(),AdvancedHMC.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:p,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{DistributionsAD.VectorOfMultivariate{Continuous,Dirichlet{Float64},FillArrays.Fill{Dirichlet{Float64},1,Tuple{Base.OneTo{Int64}}}},1},Array{DynamicPPL.VarName{:p,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}, ::DynamicPPL.DefaultContext) at model.jl:135
Model at model.jl:96 [inlined]
Model at model.jl:84 [inlined]
(::Turing.Core.var"#f#24"{DynamicPP...

Is this a bug, or am I misunderstanding how to use `arraydist`?

Will have a look at it.

1 Like