Custom univariate distribution in Turing.jl

I’m trying to use a custom univariate distribution in Turing.jl.
(Its support is [0, +Inf).)

I implemented Distributions.pdf, logpdf, rand, minimum, maximum, etc., following the Advanced Usage section of the documentation.
I’m pretty sure that every method works fine.

However, sampling from the model below throws the following error.
I understand that the error comes from ForwardDiff.jl, but I have no idea where to fix it…

Any suggestions are welcome.
Thank you in advance.


Implementation of the custom distribution

using Distributions, Random

# Rayleigh-type distribution on [0, +Inf); every formula uses s = D*δ + ϵ^2
mutable struct Diffusion{T<:Real} <: ContinuousUnivariateDistribution
    D::T
    δ::T
    ϵ::T
    Diffusion{T}(D::T, δ::T, ϵ::T) where {T <: Real} = new{T}(D, δ, ϵ)
end

# Outer constructor with argument checking; all three arguments share one type T
function Diffusion(D::T, δ::T, ϵ::T; check_args = true) where {T <: Real}
    check_args && Distributions.@check_args(
        Diffusion, D > zero(D) && δ > zero(δ) && ϵ >= zero(ϵ)
    )
    return Diffusion{T}(D, δ, ϵ)
end

# Support [0, Inf); this macro also defines minimum and maximum
Distributions.@distr_support Diffusion 0.0 Inf

Distributions.cdf(d::Diffusion, x::Real) =
    x < 0 ? zero(x) : 1 - exp(-x^2 / 4(d.D * d.δ + d.ϵ^2))
Distributions.quantile(d::Diffusion, p::Real) = sqrt(-4(d.D * d.δ + d.ϵ^2) * log(1 - p))
Distributions.pdf(d::Diffusion, x::Real) =
    0 ≤ x ? x / 2(d.D * d.δ + d.ϵ^2) * exp(-x^2 / 4(d.D * d.δ + d.ϵ^2)) : zero(x)
# Outside the support the log-density is -Inf, not 0
Distributions.logpdf(d::Diffusion, x::Real) =
    0 ≤ x ? log(x) - log(2(d.D * d.δ + d.ϵ^2)) - x^2 / 4(d.D * d.δ + d.ϵ^2) : -Inf

# Inverse-transform sampling; rand(d, n) calls this method once per draw
Base.rand(rng::AbstractRNG, d::Diffusion) = quantile(d, rand(rng))
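One way to back up the claim that the methods are consistent is to check the closed-form formulas against each other numerically: quantile should invert cdf, pdf should be the derivative of cdf, and logpdf should equal log(pdf). The sketch below is only an illustration under that assumption: it rewrites the same formulas as plain functions with hypothetical names (diff_cdf, diff_pdf, diff_logpdf, diff_quantile) so it runs in Base Julia without Distributions.jl.

```julia
# Hypothetical plain-function versions of the Diffusion formulas, in terms of s = D*δ + ϵ^2
diffusion_scale(D, δ, ϵ) = D * δ + ϵ^2

diff_cdf(s, x)      = x < 0 ? 0.0 : 1 - exp(-x^2 / 4s)
diff_pdf(s, x)      = x < 0 ? 0.0 : x / 2s * exp(-x^2 / 4s)
diff_logpdf(s, x)   = x < 0 ? -Inf : log(x) - log(2s) - x^2 / 4s
diff_quantile(s, p) = sqrt(-4s * log(1 - p))

s = diffusion_scale(0.5, 0.02, 0.03)
points = (0.05, 0.1, 0.2, 0.4)

# 1. quantile should invert cdf
for x in points
    @assert isapprox(diff_quantile(s, diff_cdf(s, x)), x; rtol = 1e-10)
end

# 2. pdf should match a central finite difference of cdf
h = 1e-6
for x in points
    fd = (diff_cdf(s, x + h) - diff_cdf(s, x - h)) / 2h
    @assert isapprox(fd, diff_pdf(s, x); rtol = 1e-6)
end

# 3. logpdf should equal log(pdf) inside the support
for x in points
    @assert isapprox(diff_logpdf(s, x), log(diff_pdf(s, x)); rtol = 1e-10)
end

println("all consistency checks passed")
```

Running the same three checks against the real Distributions methods (cdf, pdf, logpdf, quantile on a Diffusion value) would catch sign or precedence mistakes in the formulas before Turing ever sees them.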

Code for Turing.jl

using Turing

@model MyModel(x) = begin
    a ~ Gamma(1., 1.)                    # prior on D
    N = length(x)
    for n in 1:N
        x[n] ~ Diffusion(a, 0.02, 0.03)  # likelihood; NUTS differentiates w.r.t. a
    end
end

# simulate data, then run the sampler
x = rand(Diffusion(0.5, 0.02, 0.03), 1000)
sample(MyModel(x), NUTS(0.65), 1000)
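The simulated data come from rand(Diffusion(0.5, 0.02, 0.03), 1000), which amounts to inverse-transform sampling through the quantile function. Comparing the pdf with the standard Rayleigh density x/σ² · exp(-x²/(2σ²)) shows that this is a Rayleigh distribution with σ² = 2(Dδ + ϵ²), so its mean is sqrt(π(Dδ + ϵ²)). The sketch below checks a sampler against that mean; sample_diffusion is a hypothetical helper that only mirrors the quantile formula, using Base Julia plus the Random and Statistics standard libraries.

```julia
using Random, Statistics

# Hypothetical helper: inverse-transform sampling with the Diffusion quantile formula.
# If u ~ Uniform(0, 1), then sqrt(-4s * log(1 - u)) follows the distribution.
function sample_diffusion(rng::AbstractRNG, D, δ, ϵ, n::Integer)
    s = D * δ + ϵ^2
    return [sqrt(-4s * log(1 - rand(rng))) for _ in 1:n]
end

rng = MersenneTwister(42)
D, δ, ϵ = 0.5, 0.02, 0.03
xs = sample_diffusion(rng, D, δ, ϵ, 100_000)

# Rayleigh with σ² = 2(Dδ + ϵ²) has mean σ * sqrt(π/2) = sqrt(π(Dδ + ϵ²))
expected_mean = sqrt(π * (D * δ + ϵ^2))
println("sample mean = ", mean(xs), ", expected ≈ ", expected_mean)

@assert all(x -> x >= 0, xs)
@assert abs(mean(xs) - expected_mean) < 0.005
```

With 100,000 draws the standard error of the mean is well below the 0.005 tolerance, so a failure here would point to a formula problem rather than sampling noise.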

Error messages

MethodError: no method matching Diffusion(::ForwardDiff.Dual{ForwardDiff.Tag{Turing.Core.var"#f#7"{DynamicPPL.VarInfo{NamedTuple{(:a,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:a},Int64},Array{Gamma{Float64},1},Array{DynamicPPL.VarName{:a},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64},DynamicPPL.Model{var"##inner_function#593#21",NamedTuple{(:x,),Tuple{Array{Float64,1}}},DynamicPPL.ModelGen{(:x,),var"###MyModel#607",NamedTuple{(),Tuple{}}},Val{()}},DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:a,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:a},Int64},Array{Gamma{Float64},1},Array{DynamicPPL.VarName{:a},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}},Float64},Float64,1}, ::Float64, ::Float64)
Closest candidates are:
  Diffusion(::T, !Matched::T, !Matched::T; check_args) where T<:Real at In[23]:14
  ...

You seem to be calling a different model than the one you defined. Could you post an MWE, please?


Thank you for your reply.
Sorry, I think I pasted something different.
I edited the first post.

You defined your distribution in a way that requires all three parameters to have the same type T.
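The same kind of MethodError can be reproduced without Turing or ForwardDiff. In the sketch below, SameTypeParams is a hypothetical stand-in for the original struct (no Distributions.jl needed), and BigFloat stands in for ForwardDiff.Dual as "some Real type other than Float64".

```julia
# Hypothetical stand-in for the original definition: Julia's default outer
# constructor only accepts three arguments of one and the same type T.
struct SameTypeParams{T<:Real}
    D::T
    δ::T
    ϵ::T
end

SameTypeParams(1.0, 0.02, 0.03)   # works: all three arguments are Float64

# Mixing types, as happens when `a` becomes a Dual during AD, finds no method.
err = try
    SameTypeParams(big(1.0), 0.02, 0.03)
catch e
    e
end
println(typeof(err))   # MethodError
```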

During AD, the type of a in x[n] ~ Diffusion(a, 0.02, 0.03) inevitably changes: ForwardDiff passes a in as a ForwardDiff.Dual, while 0.02 and 0.03 stay Float64, so no method matches. So you could do the following instead:

struct Diffusion{T1<:Real,T2<:Real,T3<:Real} <: ContinuousUnivariateDistribution
    D::T1
    δ::T2
    ϵ::T3
    Diffusion{T1,T2,T3}(D::T1, δ::T2, ϵ::T3) where {T1<:Real,T2<:Real,T3<:Real} = new{T1,T2,T3}(D, δ, ϵ)
end

and change the rest of the code accordingly. Also, the distribution doesn’t need to be mutable. Just use struct; Julia will be happier that way. :wink:
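Below is a minimal sketch of two ways around the type restriction, plus the point about mutability. The type names are hypothetical and the structs do not subtype ContinuousUnivariateDistribution, so the block runs in Base Julia. Option 1 is the separate-type-parameter approach described above; Option 2 is an alternative that keeps a single T and promotes the arguments in an outer constructor, which, as far as I know, is the pattern Distributions.jl uses for built-in types such as Normal.

```julia
# Option 1: one type parameter per field, so mixed argument types are accepted.
struct MixedParams{T1<:Real,T2<:Real,T3<:Real}
    D::T1
    δ::T2
    ϵ::T3
end

# Option 2 (alternative): keep a single T, but promote mixed arguments first.
struct PromotedParams{T<:Real}
    D::T
    δ::T
    ϵ::T
end
PromotedParams(D::Real, δ::Real, ϵ::Real) = PromotedParams(promote(D, δ, ϵ)...)

# A mutable version, only to compare memory layout.
mutable struct MutableParams{T<:Real}
    D::T
    δ::T
    ϵ::T
end

m = MixedParams(big(1.0), 0.02, 0.03)      # no MethodError with mixed types
p = PromotedParams(big(1.0), 0.02, 0.03)   # all fields promoted to BigFloat
println(typeof(m))
println(typeof(p))

# Immutable structs of plain bits fields are isbits (stored inline); mutable ones never are.
println(isbitstype(MixedParams{Float64,Float64,Float64}))   # true
println(isbitstype(MutableParams{Float64}))                 # false
```

Option 2 keeps the single type parameter from the original post, so the rest of the code does not need to change; during AD the constants 0.02 and 0.03 are simply converted to the Dual type.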


So these two definitions are different!
I didn’t know that :dizzy_face:

Diffusion{T1,T2,T3}(D::T1, δ::T2, ϵ::T3) where {T1<:Real,T2<:Real,T3<:Real}
Diffusion{T}(D::T, δ::T, ϵ::T) where {T<:Real}

The sampler started working!
I’ve learned a lot!
Thank you so much😊!!!