I am having trouble using Turing.jl with a custom distribution that uses a StaticArray as an internal parameter. Consider the simple code below, where `TestDistribution`
is effectively a simple multivariate normal distribution: the data are split into 3D blocks, each normally distributed around the same 3D mean with unit covariance:
using Distributions
using Random
using Bijectors
using LinearAlgebra
using StaticArrays
"""
    TestDistribution{T<:Real}

Toy multivariate distribution over vectors of length `length(data)`. The vector is
treated as consecutive blocks of `length(a) == 3` elements, each block normally
distributed around `a` with unit covariance.

# Fields
- `a::SVector{3,T}`: mean of every 3-element block.
- `data::Vector{T}`: in the methods shown here only its length is used; it sets the
  dimension of the distribution.
"""
struct TestDistribution{T<:Real} <: Distribution{Multivariate,Continuous}
    a::SVector{3,T}
    data::Vector{T}
end

# Promoting constructor. The default constructor requires `a` and `data` to share the
# same element type. Inside a model, `a` can carry AD numbers (e.g. Duals) while
# `data` stays `Vector{Float64}`, so both are converted to their promoted type.
# Same-type calls still dispatch to the (more specific) default constructor.
function TestDistribution(a::SVector{3,<:Real}, data::AbstractVector{<:Real})
    T = promote_type(eltype(a), eltype(data))
    return TestDistribution{T}(a, data)
end
# Dimension of the distribution: one entry per element of `d.data`.
function Distributions.length(d::TestDistribution)
    return length(d.data)
end

# Shape of a single sample: a vector with `length(d.data)` entries.
function Distributions.size(d::TestDistribution)
    n = length(d.data)
    return (n,)
end

# Element type is the struct's type parameter.
function Distributions.eltype(::TestDistribution{T}) where {T}
    return T
end
"""
    Distributions._logpdf(d::TestDistribution, x::AbstractVector)

Return the log-density of `x` under `d`. `x` is split into consecutive blocks of
`length(d.a)` elements, and each block is scored as an independent normal with mean
`d.a` and identity covariance. The full normalization `-length(x)/2 * log(2π)` is
included, so the value is a proper log-density; this matters e.g. for log evidence.

Throws `DimensionMismatch` if `length(x)` is not a multiple of `length(d.a)`.
"""
function Distributions._logpdf(d::TestDistribution{T}, x::AbstractVector{S}) where {T,S}
    k = length(d.a)
    n = length(x)
    n % k == 0 || throw(DimensionMismatch(
        "length(x) = $n is not a multiple of length(d.a) = $k"))
    # Accumulate in the promoted type of the mean and the data, so the sum stays
    # type-stable when either side carries AD numbers.
    R = promote_type(T, S)
    # Pairing `x` with an endlessly repeated `d.a` compares every element to its block
    # mean without allocating a slice per block.
    sqdist = sum(abs2(μ - xi) for (μ, xi) in zip(Iterators.cycle(d.a), x);
                 init = zero(R))
    # `one(R) * π` keeps the constant in R (Float32 stays Float32, Duals stay Duals).
    return -(sqdist + n * log(2 * one(R) * π)) / 2
end
"""
Fill `x` in place with one draw from `d` and return `x`.

The draw is `d.a` tiled `length(d.data) ÷ length(d.a)` times, plus independent
standard-normal noise taken from `rng` in a single `randn` call.
"""
function Distributions._rand!(rng::Random.AbstractRNG, d::TestDistribution, x::AbstractVector)
    blocksize = length(d.a)
    nblocks = length(d.data) ÷ blocksize
    noise = randn(rng, blocksize * nblocks)
    tiled_mean = repeat(d.a, nblocks)
    x .= tiled_mean .+ noise
    return x
end
# Every real vector is in the support (the density is Gaussian). The name is qualified
# so this extends `Distributions.insupport`. An unqualified definition would create an
# unrelated `Main.insupport` that Distributions/Turing never call.
Distributions.insupport(::TestDistribution, x::AbstractVector) = true
# The support is unconstrained (`insupport` is always true), so no change of variables
# is needed: map `TestDistribution` to the identity bijector on vectors.
# NOTE(review): `Identity{1}` exists in older Bijectors releases; newer releases may
# expect plain `identity` instead — confirm against the Bijectors version in use.
function Bijectors.bijector(::TestDistribution)
    return Identity{1}()
end
# Ground-truth distribution used to simulate data. The block mean is (2, 3, -1); the
# `rand(30)` vector only fixes the dimension (30 entries, i.e. 10 blocks of 3).
t = TestDistribution(SVector(2.0, 3.0, -1.0), rand(30))
# One synthetic observation vector of length 30 drawn from `t` (global RNG).
data = rand(t)
using Turing
# Toy Turing model: a uniform prior on the 3-element block mean `x`, then the whole
# `data` vector is observed under a `TestDistribution` centred on `x`.
@model function testmodel(data)
    x ~ filldist(Uniform(-8, 8), 3)
    println(x)  # debugging output: shows every `x` the sampler proposes
    likelihood = TestDistribution(SVector{3}(x), data)
    data ~ likelihood
end
# Draw 10 samples of `x` from the posterior of `testmodel` given `data`, using SMC.
# NOTE(review): each SMC particle presumably samples `x` from the uniform prior, and the
# only observation happens once at the end of the model. With a sharply peaked
# likelihood (30 observed values), the final resampling step can then collapse all
# particles onto a single one, which would explain a zero-variance chain. Consider
# many more particles or a gradient-based sampler such as NUTS — confirm.
chain = sample(testmodel(data), SMC(), 10)
When I run this code, I obtain a chain with zero variance in the parameters:
Chains MCMC chain (10×5×1 Array{Float64, 3}):
Log evidence = -138.26892888869583
Iterations = 1:1:10
Number of chains = 1
Samples per chain = 10
Wall duration = 5.37 seconds
Compute duration = 5.37 seconds
parameters = x[1], x[2], x[3]
internals = lp, weight
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_se ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float6 ⋯
x[1] -0.3968 0.0000 0.0000 0.0000 NaN NaN Na ⋯
x[2] -1.1723 0.0000 0.0000 0.0000 NaN NaN Na ⋯
x[3] -0.4092 0.0000 0.0000 0.0000 NaN NaN Na ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
x[1] -0.3968 -0.3968 -0.3968 -0.3968 -0.3968
x[2] -1.1723 -1.1723 -1.1723 -1.1723 -1.1723
x[3] -0.4092 -0.4092 -0.4092 -0.4092 -0.4092
In fact, only the last sampled value of `x`
is retained in the chain.