Is there a way to sample from a truncated distribution in Gen? In other words, what should go in place of `??` below?
```julia
using Gen, Distributions

@gen function line_model(xs::Vector{Float64})
    slope = @trace(??, :slope)  # what should go here?
    intercept = @trace(normal(0, 2), :intercept)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, i))
    end
end;

xs = [-5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.];
trace = Gen.simulate(line_model, (xs,));
```
I’ve tried:

- `Distributions.Truncated(normal(0, 1), 0, Inf)`. I believe this doesn’t work because Gen’s `normal(0, 1)` returns a sampled `Float64`, whereas `Truncated` expects a Distributions.jl distribution object.
- I considered implementing this using `@dist`, but couldn’t formulate truncation as a deterministic transformation of a sample:

```julia
@dist function truncated_normal(mean, sd, lb, ub)
    max(lb, min(normal(mean, sd), ub))
end
```
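To make the difference behind the second attempt concrete, here is a small sketch using only Distributions.jl (no Gen). Clamping a draw into `[lb, ub]`, as the `@dist` version does, moves all out-of-range mass onto the bounds (a censored distribution with a point mass at `lb`), whereas Distributions.jl's `truncated` discards that mass and renormalises the density. The seed, sample size and variable names are arbitrary choices for illustration.

```julia
import Distributions
import Random

Random.seed!(1)  # arbitrary seed, for reproducibility
n = 10_000
base = Distributions.Normal(0, 1)

# Clamping, as in the @dist attempt with lb = 0, ub = Inf:
# every negative draw is replaced by exactly 0.0.
clamped = [max(0.0, min(rand(base), Inf)) for _ in 1:n]

# Truncation: the density is restricted to [0, Inf) and renormalised,
# so draws do not pile up on the bound.
truncd = rand(Distributions.truncated(base, 0.0, Inf), n)

frac_clamped_at_0 = count(==(0.0), clamped) / n  # close to 0.5: a point mass at 0
frac_truncd_at_0 = count(==(0.0), truncd) / n    # no point mass at 0
println((frac_clamped_at_0, frac_truncd_at_0))
```

Note that `truncated` takes a distribution object such as `Normal(0, 1)`, not a sampled value, which is why passing Gen's `normal(0, 1)` to it fails.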
Is there a way to implement `Truncated` from Distributions.jl for any distribution in Gen?

Thanks.
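One possible approach, sketched under assumptions rather than offered as Gen's official answer: Gen lets you define a new distribution by subtyping `Gen.Distribution{T}` and implementing `Gen.random` and `Gen.logpdf` (plus the gradient-related methods). The hypothetical `TruncatedDist` / `truncated_dist` below wraps Distributions.jl's `truncated` for any continuous univariate distribution passed in as an argument. It assumes a Gen version with this custom-distribution interface and declares no gradient support, so gradient-based inference on `:slope` would not work with it.

```julia
using Gen
import Distributions  # `import`, not `using`, to avoid clashing with Gen's `logpdf`

# Hypothetical Gen distribution: a Distributions.jl continuous univariate
# distribution `d` truncated to [lb, ub].
struct TruncatedDist <: Gen.Distribution{Float64} end
const truncated_dist = TruncatedDist()

# Hypothetical helper: build the truncated Distributions.jl object.
_truncate(d, lb, ub) = Distributions.truncated(d, promote(lb, ub)...)

function Gen.random(::TruncatedDist, d::Distributions.ContinuousUnivariateDistribution,
                    lb::Real, ub::Real)
    return rand(_truncate(d, lb, ub))
end

function Gen.logpdf(::TruncatedDist, x::Real,
                    d::Distributions.ContinuousUnivariateDistribution, lb::Real, ub::Real)
    # Distributions.jl renormalises by the mass in [lb, ub] and returns -Inf outside it.
    return Distributions.logpdf(_truncate(d, lb, ub), x)
end

# Allow calling truncated_dist(d, lb, ub) directly, outside @trace.
(::TruncatedDist)(d, lb, ub) = Gen.random(TruncatedDist(), d, lb, ub)

# This sketch provides no gradients.
Gen.is_discrete(::TruncatedDist) = false
Gen.has_output_grad(::TruncatedDist) = false
Gen.has_argument_grads(::TruncatedDist) = (false, false, false)
Gen.logpdf_grad(::TruncatedDist, x, d, lb, ub) = (nothing, nothing, nothing, nothing)

@gen function line_model(xs::Vector{Float64})
    # Slope prior: standard normal truncated to [0, Inf).
    slope = @trace(truncated_dist(Distributions.Normal(0, 1), 0.0, Inf), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, i))
    end
end

xs = [-5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.]
trace = Gen.simulate(line_model, (xs,))
```

Passing the base distribution as an argument keeps the wrapper generic: any other continuous Distributions.jl distribution, e.g. `Distributions.Gamma(2, 1)`, can be truncated the same way. The trade-off is that the truncated object is rebuilt on every `random`/`logpdf` call.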