Truncated distributions in Gen

Is there a way to sample from a truncated distribution in Gen? In other words, what should go in ?? below?

using Gen, Distributions

@gen function line_model(xs::Vector{Float64})

    slope = @trace(??, :slope)
    intercept = @trace(normal(0, 2), :intercept)

    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, i))
    end

end;

xs = [-5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.];
trace = Gen.simulate(line_model, (xs,));

I’ve tried:

  1. Distributions.Truncated(normal(0, 1), 0, Inf). I believe this doesn’t work because Gen’s normal(0, 1) returns a float.
  2. I considered implementing this using @dist, but couldn’t express truncation as a deterministic transformation. The closest I got was clamping:
@dist function truncated_normal(mean, sd, lb, ub)
    max(lb, min(normal(mean, sd), ub))
end
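Clamping and truncation are different distributions. `max(lb, min(x, ub))` moves all the probability mass outside [lb, ub] onto the endpoints, while truncation discards that mass and renormalizes the density inside the interval. A small sketch using only Distributions.jl (no Gen) shows the difference. The seed and sample size are arbitrary choices.

```julia
using Random
import Distributions

Random.seed!(1)  # arbitrary seed, only for reproducibility
lb, ub = 0.0, Inf
base = Distributions.Normal(0.0, 1.0)
n = 10_000

# Clamping (the @dist attempt): draws below lb are moved onto lb itself,
# so about half of all draws land exactly at 0.0 (a point mass)
clamped_draws = [max(lb, min(rand(base), ub)) for _ in 1:n]

# Truncation: draws come only from [lb, ub], with the density renormalized
truncated_draws = rand(Distributions.truncated(base, lb, ub), n)

println("clamped draws exactly at lb:   ", count(==(lb), clamped_draws) / n)
println("truncated draws exactly at lb: ", count(==(lb), truncated_draws) / n)
```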

Is there a way to implement Truncated from Distributions.jl for any distribution in Gen?

Thanks.

Found a solution using Genify.

Define:

function truncated_normal(mean, sd, lb, ub)
    # Normal(mean, sd) restricted to [lb, ub] and renormalized
    d = truncated(Normal(mean, sd), lb, ub)
    # Genify will trace this rand call as a random choice
    x = rand(d)
end
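For reference, the distribution built in truncated_normal has the standard truncated-normal density below, with μ = mean, σ = sd, a = lb, b = ub, and φ, Φ the standard normal density and CDF. For the slope prior used here (μ = 0, σ = 1, a = 0, b = ∞) the denominator is 1 − 1/2 = 1/2, so the slope gets density 2φ(x) for x ≥ 0, i.e. a half-normal.

```latex
f(x \mid \mu, \sigma, a, b) =
\begin{cases}
\dfrac{\frac{1}{\sigma}\,\varphi\!\left(\frac{x-\mu}{\sigma}\right)}
      {\Phi\!\left(\frac{b-\mu}{\sigma}\right) - \Phi\!\left(\frac{a-\mu}{\sigma}\right)}, & a \le x \le b,\\[2ex]
0, & \text{otherwise.}
\end{cases}
```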

and

using Genify

gen_truncated_normal = genify(truncated_normal, Real, Real, Real, Real)

Then, modify the line in line_model to:

slope = @trace(gen_truncated_normal(0, 1, 0, Inf), :slope)


Awesome! For people finding this now, you can also use the new GenDistributions.jl package (probcomp/GenDistributions.jl on GitHub) to call any Distributions.jl distribution from within a Gen model.
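Without an extra package, Gen’s custom-distribution interface (`random`, `logpdf`, `is_discrete`, `has_output_grad`, `has_argument_grads`, `logpdf_grad`) can also wrap a Distributions.jl truncation directly. This is a minimal sketch, not the GenDistributions.jl API: the names `TruncatedDist` and `gen_truncated` are hypothetical, it assumes a continuous univariate base distribution, and it reports no gradients, so gradient-based moves on `:slope` are not supported by it.

```julia
using Gen
import Distributions

# Hypothetical wrapper: a Gen distribution over Float64 whose arguments are a
# continuous univariate Distributions.jl distribution and the bounds [lb, ub]
struct TruncatedDist <: Gen.Distribution{Float64} end
const gen_truncated = TruncatedDist()

# Sampling: delegate to Distributions.truncated
function Gen.random(::TruncatedDist, d::Distributions.ContinuousUnivariateDistribution,
                    lb::Real, ub::Real)
    return rand(Distributions.truncated(d, lb, ub))
end

# Log density of the renormalized distribution; -Inf outside [lb, ub]
function Gen.logpdf(::TruncatedDist, x::Real, d::Distributions.ContinuousUnivariateDistribution,
                    lb::Real, ub::Real)
    return Distributions.logpdf(Distributions.truncated(d, lb, ub), x)
end

# Make gen_truncated(d, lb, ub) callable like Gen's built-in distributions
(::TruncatedDist)(d, lb, ub) = Gen.random(TruncatedDist(), d, lb, ub)

# Continuous, and no gradient support in this sketch
Gen.is_discrete(::TruncatedDist) = false
Gen.has_output_grad(::TruncatedDist) = false
Gen.has_argument_grads(::TruncatedDist) = (false, false, false)
Gen.logpdf_grad(::TruncatedDist, x, d, lb, ub) = (nothing, nothing, nothing, nothing)

@gen function line_model(xs::Vector{Float64})
    # Slope prior: standard normal truncated to [0, Inf)
    slope = @trace(gen_truncated(Distributions.Normal(0.0, 1.0), 0.0, Inf), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, i))
    end
    return slope
end

xs = [-5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.]
trace = Gen.simulate(line_model, (xs,))
println("slope = ", trace[:slope])
```

Because the base distribution is an ordinary argument, the same wrapper covers, for example, `gen_truncated(Distributions.Gamma(2.0, 1.0), 0.5, 3.0)`. A discrete base distribution would need a separate `Distribution{Int}` wrapper with `is_discrete` returning `true`.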


Thanks, that’s great!

Sweet! Thanks so much, @Alex_Lew!

Can we register this package?
