Making a more type-stable arraydist for Turing?

I recently noticed that making an arraydist prior composed of different types of distributions leads to type instability (which seems a bit obvious, in hindsight). Playing around, I thought making a union of the component types might help.
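A quick way to see the problem outside a model (sketch): the heterogeneous array promotes to an abstract element type, which is what triggers the instability.

```julia
using Distributions

# a heterogeneous array of distributions promotes to an abstract eltype
dists = [truncated(Normal(0, 10), 0, Inf), Normal(1, 10)]

eltype(dists)  # Distribution{Univariate, Continuous} -- abstract
@assert !isconcretetype(eltype(dists))
```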

It does seem to, but I’m not sure how to make the approach easy to use in an actual model.

Below, I have three versions of the same model. Model 1 makes a heterogeneous array of distributions and is type-unstable. Model 2 constructs an array whose elements are a union of the types of the component distributions. It's type-stable, but obviously a pain to write out. So I tried to write a function that makes it easier to instantiate an array of the right type (see function arraydist_type below).

using Turing
using Random  # for Random.GLOBAL_RNG below

# (1) regular model
@model mwe1() = begin
    w ~ arraydist([truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
end


# (2) Direct construction of Array{Union}
@model mwe2() = begin
    w ~ arraydist(Union{Truncated{Normal{Float64}, Continuous, Float64}, Normal{Float64}}[truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
end


# Attempt to make an arraydist using type unions to fix type stability issue
function arraydist_type(y)
    dists = Union{typeof.(y)...}[y[i] for i in eachindex(y)]
    return Product(dists)
end

# (3) arraydist_type function to simplify the method in (2) above.
@model mwe3() = begin
    w ~ arraydist_type([truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
end


# Check first method
test1 = mwe1()

@code_warntype test1.f(
    Random.GLOBAL_RNG,
    test1,
    Turing.VarInfo(test1),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    test1.args...,
)


# Check second method
test2 = mwe2()

@code_warntype test2.f(
    Random.GLOBAL_RNG,
    test2,
    Turing.VarInfo(test2),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    test2.args...,
)


# Check third method
test3 = mwe3()

@code_warntype test3.f(
    Random.GLOBAL_RNG,
    test3,
    Turing.VarInfo(test3),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    test3.args...,
)

As far as I can tell, the approaches used inside Models 2 and 3 look like they result in exactly the same kind of object:

a = arraydist(Union{Truncated{Normal{Float64}, Continuous, Float64}, Normal{Float64}}[truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
b = arraydist_type([truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])

# a and b *look* exactly the same, and typeof(a) == typeof(b), but one is type stable inside the Turing model, and the other one isn't.

However, Model 3 is still type-unstable.

Anyone have suggestions for ways to approach this kind of problem? My arraydist_type function doesn't work, but I can't really identify the issue. FWIW, my intended application involves an MxN matrix of heterogeneous distributions, not just a single vector.

You might try making all entries truncated, even if the bounds are infinite. I'm not at my computer, but I'd guess that might work.


Thanks!! That’s a fantastic idea, and works.
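For reference, the fix amounts to something like this (sketch; the exact Truncated type parameters depend on your Distributions.jl version, but the point is that the eltype becomes concrete):

```julia
using Distributions

# wrap every component in truncated, even with infinite bounds,
# so all elements share a single concrete Truncated type
dists = [truncated(Normal(0, 10), 0, Inf), truncated(Normal(1, 10), -Inf, Inf)]

@assert isconcretetype(eltype(dists))
```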

(I was starting to go down the rabbit hole of making an array of union-types. That seems to work for the general case, but I’m not familiar enough with making parametric structs, and was getting stuck on some subtyping issues…)

Just to be complete, and in case anyone runs into a similar problem, it turns out my original (full) code had two issues: one was the combination of heterogeneous distributions into an array. The other was how I was constructing the array.

In my full code, I was constructing the array of distributions using a comprehension based on a ternary operator, doing something like this:

w ~ arraydist([a[i] ? truncated(Normal(0, s[i]), -Inf, Inf) : truncated(Normal(0, s[i]), 0, Inf) for i in eachindex(a)])

which was producing a type-unstable model, even after the suggestion above to make everything truncated. I fixed it by putting in a function barrier:

function mixed_dist(a, s)
    M = [a[i] ? truncated(Normal(0, s[i]), -Inf, Inf) : truncated(Normal(0, s[i]), 0, Inf) for i in eachindex(a)]
    return M
end

...

w ~ arraydist(mixed_dist(a, s))

I admit I don’t quite get why a function barrier was needed here, but that, plus @cscherrer’s excellent suggestion above did the trick.


Great! I'm glad it worked out 🙂

I think the promotion rules could be improved so this doesn’t come up. For example,

julia> [1, 2, 3.0, 4]
4-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0

But in Distributions.jl,

julia> [Normal(), truncated(Normal(0,10),0,Inf)]
2-element Vector{Distribution{Univariate, Continuous}}:
 Normal{Float64}(μ=0.0, σ=1.0)
 Truncated(Normal{Float64}(μ=0.0, σ=10.0), range=(0.0, Inf))

Analogously to the Int/Float64 case, this could instead return

julia> [truncated(Normal(),-Inf, Inf), truncated(Normal(0,10),0,Inf)]
2-element Vector{Truncated{Normal{Float64}, Continuous, Float64}}:
 Truncated(Normal{Float64}(μ=0.0, σ=1.0), range=(-Inf, Inf))
 Truncated(Normal{Float64}(μ=0.0, σ=10.0), range=(0.0, Inf))