I recently noticed that making an `arraydist` prior composed of different types of distributions leads to type instability (which seems a bit obvious in hindsight). Playing around, I thought making a `Union` of the component types might help.
It does seem to, but I'm not sure how to make the approach easy to use in an actual model.
Below, I have three versions of the same model. Model 1 makes a heterogeneous array of distributions and is type-unstable. Model 2 constructs an array whose element type is a `Union` of the component distribution types. It's type-stable, but obviously a pain to write out. So I tried to write a function that makes it easier to instantiate an array of the right type (see the function `arraydist_type` below).
```julia
using Turing
using Random            # for Random.GLOBAL_RNG
using InteractiveUtils  # for @code_warntype outside the REPL

# (1) Regular model: heterogeneous array of distributions
@model mwe1() = begin
    w ~ arraydist([truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
end

# (2) Direct construction of an Array{Union{...}}
@model mwe2() = begin
    w ~ arraydist(Union{Truncated{Normal{Float64}, Continuous, Float64}, Normal{Float64}}[truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
end

# Attempt to make an arraydist using type unions to fix the type-stability issue
function arraydist_type(y)
    # Element type: the Union of the concrete types of the elements of y
    T = Union{map(typeof, y)...}
    dists = T[y[i] for i in eachindex(y)]
    return Product(dists)
end

# (3) arraydist_type helper to simplify the method in (2) above
@model mwe3() = begin
    w ~ arraydist_type(a)
end

# Check the first method
test1 = mwe1()
@code_warntype test1.f(
    Random.GLOBAL_RNG,
    test1,
    Turing.VarInfo(test1),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    test1.args...,
)

# Check the second method
test2 = mwe2()
@code_warntype test2.f(
    Random.GLOBAL_RNG,
    test2,
    Turing.VarInfo(test2),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    test2.args...,
)

# Check the third method
test3 = mwe3()
@code_warntype test3.f(
    Random.GLOBAL_RNG,
    test3,
    Turing.VarInfo(test3),
    Turing.SampleFromPrior(),
    Turing.DefaultContext(),
    test3.args...,
)
```
As far as I can tell, the approaches used inside Models 2 and 3 result in exactly the same kind of object:

```julia
a = arraydist(Union{Truncated{Normal{Float64}, Continuous, Float64}, Normal{Float64}}[truncated(Normal(0, 10), 0, Inf), Normal(1, 10)])
b = arraydist_type([truncated(Normal(0, 10), 0, Inf), Normal(0, 10)])
# a and b *look* exactly the same, and typeof(a) == typeof(b), but one is type-stable inside the Turing model and the other isn't.
```

However, Model 3 is still type-unstable.
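My best guess at the difference: in Model 2 the `Union` is written out in the source, so the compiler knows the element type, whereas `arraydist_type` builds the `Union` from the *values* in `y` at run time, so inference cannot pin down the element type even though the runtime type ends up identical. (Separately, `mwe3` as written refers to `a`, which is not an argument of the model; if that resolves to a non-`const` global, that alone makes the body type-unstable.) Below is a sketch with the same hypothetical stand-in types comparing a runtime-built `Union` against a hypothetical `union_vector` helper written as a `@generated` function, which sees the argument types at compile time. `Base.return_types` shows what the compiler infers for each.

```julia
# Hypothetical stand-ins for the abstract supertype and two concrete distributions
abstract type MyDist end
struct HalfDist <: MyDist
    σ::Float64
end
struct FullDist <: MyDist
    μ::Float64
    σ::Float64
end

# Same idea as arraydist_type, minus the Product wrapper: the Union is built
# from the *values* in y, so the element type is only known at run time.
runtime_union(y) = Union{map(typeof, y)...}[y[i] for i in eachindex(y)]

# Hypothetical alternative: a generated function's generator sees the argument
# *types*, so the Union is fixed when the method is compiled.
@generated function union_vector(xs...)
    T = Union{xs...}        # in the generator, xs is the tuple of argument types
    return :($T[xs...])     # in the generated code, xs is the tuple of values
end

v1 = runtime_union(MyDist[HalfDist(10.0), FullDist(1.0, 10.0)])
v2 = union_vector(HalfDist(10.0), FullDist(1.0, 10.0))

# Same runtime type...
println(typeof(v1) == typeof(v2))   # true
# ...but very different inferred return types
println(Base.return_types(runtime_union, (Vector{MyDist},)))
println(Base.return_types(union_vector, (HalfDist, FullDist)))
```

If this reasoning is right, something like `w ~ arraydist(union_vector(truncated(Normal(0, 10), 0, Inf), Normal(1, 10)))` might be type-stable inside a model, but I haven't checked that in Turing.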
Does anyone have suggestions for ways to approach this kind of problem? My `arraydist_type` function doesn't work, but I can't really identify the issue. FWIW, my intended application involves an M×N matrix of heterogeneous distributions, not just a single vector.
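For the M×N case, one option I'm considering (sketched below, not tested inside a model) is to name the `Union` once as a `const` alias and allocate a `Matrix` with that element type. `MyDist`, `HalfDist`, `FullDist`, `MixedDist`, and the `mixed_prior_matrix` helper are all hypothetical. With the real packages, the alias could presumably be written as `const MixedDist = Union{typeof(truncated(Normal(0, 10), 0, Inf)), typeof(Normal(1, 10))}`, which avoids spelling out `Truncated`'s type parameters, and the matrix would then be passed to `arraydist`; I haven't verified that this combination is type-stable.

```julia
# Hypothetical stand-ins for the abstract supertype and two concrete distributions
abstract type MyDist end
struct HalfDist <: MyDist
    σ::Float64
end
struct FullDist <: MyDist
    μ::Float64
    σ::Float64
end

# Name the Union once, as a constant, so every use site sees the same element type
const MixedDist = Union{HalfDist, FullDist}

# Hypothetical helper: column 1 gets HalfDist priors, the other columns FullDist
function mixed_prior_matrix(M, N)
    dists = Matrix{MixedDist}(undef, M, N)
    for j in 1:N, i in 1:M
        dists[i, j] = j == 1 ? HalfDist(10.0) : FullDist(1.0, 10.0)
    end
    return dists
end

D = mixed_prior_matrix(3, 4)
println(size(D))   # (3, 4)
```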