Turing.jl and StaticArrays

I am having trouble using Turing.jl with a custom distribution that uses a StaticArray as an internal parameter. Consider the simple code below, where `TestDistribution` is effectively a 3-D multivariate normal distribution with unit covariance, applied to each consecutive block of three data entries:

using Distributions
using Random
using Bijectors
using LinearAlgebra
using StaticArrays

# Toy multivariate distribution: the `length(data)` coordinates are split into
# consecutive blocks of three, and each block is scored/sampled as a 3-D normal
# with mean `a` and identity covariance (see `_logpdf` / `_rand!`). Only
# `length(data)` is used by the methods in this file.
#
# `a` and `data` carry independent element types so that `a` can hold AD numbers
# (e.g. ForwardDiff duals built from the model parameters) while `data` stays a
# plain vector of observations; with a single shared `T` the constructor fails
# under operator-overloading AD.
struct TestDistribution{T<:Real,S<:Real} <: Distribution{Multivariate,Continuous}
    a::SVector{3,T}
    data::Vector{S}
end

# Keep the old partially-parameterised constructor `TestDistribution{T}(a, data)`
# working: `a` is converted to `SVector{3,T}`, `data` keeps its own eltype.
TestDistribution{T}(a::AbstractVector, data::AbstractVector{S}) where {T<:Real,S<:Real} =
    TestDistribution{T,S}(SVector{3,T}(a), data)

# Array-like interface: the dimension of the distribution is the number of data
# entries, and its element type is that of the mean vector `a`.
Distributions.length(dist::TestDistribution) = dist.data |> length
Distributions.size(dist::TestDistribution) = tuple(length(dist.data))
Distributions.eltype(dist::TestDistribution) = eltype(dist.a)

# Unnormalised log-density of `x` under `d`: `x` is split into consecutive blocks
# of `length(d.a)` entries, and each block contributes `-‖d.a - block‖² / 2`.
# The Gaussian normalising constant is omitted.
# Throws `DimensionMismatch` if `length(x)` is not a multiple of `length(d.a)`.
function Distributions._logpdf(d::TestDistribution{T}, x::AbstractVector{S}) where {T,S}
    k = length(d.a)
    length(x) % k == 0 ||
        throw(DimensionMismatch("length(x) = $(length(x)) is not a multiple of $k"))
    # Accumulate in the promoted float type so the loop is type-stable whether the
    # AD numbers (e.g. ForwardDiff duals) live in `d.a` or in `x`.
    result = zero(float(promote_type(T, S)))
    # `eachindex` respects non-1-based arrays; `view` avoids copying each block.
    for block in Iterators.partition(eachindex(x), k)
        δ = d.a .- view(x, block)
        result -= dot(δ, δ) / 2
    end
    return result
end

# Fill `x` in place with one draw from `d`: independent standard normals shifted by
# `d.a`, which is tiled across consecutive blocks of `length(d.a)` entries.
# Returns `x`. Throws `DimensionMismatch` if `length(x)` is not a multiple of
# `length(d.a)`.
function Distributions._rand!(rng::Random.AbstractRNG, d::TestDistribution, x::AbstractVector)
    k = length(d.a)
    length(x) % k == 0 ||
        throw(DimensionMismatch("length(x) = $(length(x)) is not a multiple of $k"))
    # Write the noise directly into `x` instead of allocating `repeat(d.a, f)` and a
    # separate `randn` vector, then add the mean block by block.
    randn!(rng, x)
    for block in Iterators.partition(eachindex(x), k)
        view(x, block) .+= d.a
    end
    return x
end

# Extend Distributions' own `insupport`. An unqualified `insupport(...) = ...` at
# top level defines an unrelated `Main.insupport` that Distributions never calls.
Distributions.insupport(::TestDistribution, x::AbstractVector) = true
# The support is treated as unconstrained, so no transformation is applied.
# NOTE(review): `Identity{1}` is the older Bijectors API; newer Bijectors releases
# may expect `identity` instead — confirm against the installed version.
Bijectors.bijector(::TestDistribution) = Identity{1}()

# Reference distribution with mean (2, 3, -1); the 30 random entries only set its
# length (10 blocks of 3). `data` is one synthetic draw from it.
t = TestDistribution(SVector{3,Float64}(2.0, 3.0, -1.0), rand(30))
data = rand(t)

using Turing

# Turing model: a Uniform(-8, 8) prior on each coordinate of the 3-D mean `x`; the
# observed `data` is scored under `TestDistribution` with that mean.
@model function testmodel(data)
    x ~ filldist(Uniform(-8, 8), 3)
    # Use the logging system instead of `println`: the sampler may evaluate this
    # body many times (once per particle for SMC), flooding stdout. Enable with
    # `ENV["JULIA_DEBUG"] = "Main"`.
    @debug "testmodel: proposed x" x
    d = TestDistribution(SVector{3}(x), data)
    data ~ d
end

# Sample with sequential Monte Carlo, requesting 10 draws.
# NOTE(review): if SMC uses one particle per requested draw, 10 particles can
# collapse onto a single value after resampling.
chain = let model = testmodel(data)
    sample(model, SMC(), 10)
end

When running this code, I obtain a chain with zero variance on the parameters:

Chains MCMC chain (10×5×1 Array{Float64, 3}):

Log evidence      = -138.26892888869583
Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 5.37 seconds
Compute duration  = 5.37 seconds
parameters        = x[1], x[2], x[3]
internals         = lp, weight

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_se ⋯
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64       Float6 ⋯

        x[1]   -0.3968    0.0000     0.0000    0.0000       NaN       NaN           Na ⋯
        x[2]   -1.1723    0.0000     0.0000    0.0000       NaN       NaN           Na ⋯
        x[3]   -0.4092    0.0000     0.0000    0.0000       NaN       NaN           Na ⋯
                                                                        1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        x[1]   -0.3968   -0.3968   -0.3968   -0.3968   -0.3968
        x[2]   -1.1723   -1.1723   -1.1723   -1.1723   -1.1723
        x[3]   -0.4092   -0.4092   -0.4092   -0.4092   -0.4092

In fact, only the last sampled value of `x` appears in the chain.

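StaticArrays is unrelated here: using a plain `Vector` produces the same result. I suspect SMC sets the number of particles equal to the number of draws you request. With only 10 particles and a sharply peaked likelihood (30 observations), resampling keeps copies of essentially one particle, so every draw is identical. Request more draws and you at least get non-zero variance: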

julia> chain = sample(testmodel(data), SMC(), 1_000)
Chains MCMC chain (1000×5×1 Array{Float64, 3}):

Log evidence      = -25.733669871912777
Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.42 seconds
Compute duration  = 0.42 seconds
parameters        = x[1], x[2], x[3]
internals         = lp, weight

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64       Float64 

        x[1]    1.9258    0.0746     0.0024    0.0128    4.3333    1.2212       10.3669
        x[2]    2.4460    0.2640     0.0083    0.0466    4.2160    1.2249       10.0861
        x[3]   -0.1714    0.2859     0.0090    0.0328   41.5919    0.9999       99.5021

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        x[1]    1.8926    1.8926    1.8926    1.8926    2.0801
        x[2]    2.3277    2.3277    2.3277    2.3277    3.0491
        x[3]   -0.1800   -0.1800   -0.1800   -0.1800    0.1550

(though ESS and R-hat are bad; these draws should not be trusted)

On an unrelated note, giving the `data` field of `TestDistribution` the same eltype as the parameter vector `a` will not work with all samplers. In particular, gradient-based samplers that use an operator-overloading AD backend such as ForwardDiff will make the constructor error. There, `a` holds dual numbers while `data` holds plain `Float64`s, so the two eltypes differ. I recommend allowing `data` to have a different eltype or, in this case, just storing `n = length(data)` in `TestDistribution`.

1 Like