Getting off the ground with Gen

I am really interested in Gen.jl, but I am somewhat confused by the documentation: I could not find any example of the functionality I’m trying to use. Can someone help me with the first steps?

My first question is: how do I get the code below to work?
My second question is: how do I implement custom proposals, say a simple random walk proposal? I tried a couple of ways (for example, using the update function), but none of them led to sensible inferences for my simple example.

The goal of this code is just to estimate the mean parameter of a normal distribution using MH.
For me the chain just gets stuck at the initial value far from the mean of the data.

using Gen, Plots, Random

@gen function model(n::Int)
    y = Vector{Float64}(undef, n)
    m = @trace(normal(0, 10.0), :m)
    for i in 1:n
        y[i] = @trace(normal(m, 1.0), (:y, i))
    end
end

function do_inference(y, num_iters)
    trace, = generate(model, (length(y),), choicemap([((:y, i), y[i]) for i in 1:n]))
    ms = Float64[]
    for i=1:num_iters
        trace, = mh(trace, select(:m))
        @show trace[:m], get_score(trace)
        push!(ms, trace[:m])
    end
    return trace, ms
end

n = 50
(trace, _) = Gen.generate(model, (n,))
m = trace[:m]
y = [trace[(:y, i)] for i in 1:n];
trace, chn = do_inference(y, 1000)

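Re your first question: there was a bug in your code for constructing the choice map. choicemap expects each (address, value) tuple as a separate argument, so the array has to be splatted with a trailing "...". It should be: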

choicemap([((:y, i), y[i]) for i in 1:n]...)

I ran the code after making that fix, and it estimates the mean of the Gaussian as expected.
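For reference, here is a minimal sketch of the whole thing with the observations constrained correctly. It fills an empty choicemap() one address at a time instead of splatting, which is just an alternative way to build the same constraints. The synthetic data (true mean 3.0), the seed, and the burn-in cutoff are illustrative assumptions, not part of the original post.

```julia
using Gen, Random, Statistics

@gen function model(n::Int)
    m = @trace(normal(0, 10.0), :m)
    for i in 1:n
        @trace(normal(m, 1.0), (:y, i))
    end
end

function do_inference(y, num_iters)
    # Constrain every observation address (:y, i) to the observed value.
    constraints = choicemap()
    for i in eachindex(y)
        constraints[(:y, i)] = y[i]
    end
    trace, = generate(model, (length(y),), constraints)
    ms = Float64[]
    for _ in 1:num_iters
        # Selection-based MH: re-proposes :m and accepts or rejects.
        trace, = mh(trace, select(:m))
        push!(ms, trace[:m])
    end
    return trace, ms
end

Random.seed!(1)
y = 3.0 .+ randn(50)   # illustrative synthetic data, true mean 3.0
trace, ms = do_inference(y, 1000)
println("estimate: ", mean(ms[501:end]), "   data mean: ", mean(y))
```

The printed estimate (averaged over the second half of the chain) should land close to the data mean if the constraints were applied properly.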

Re your second question: custom proposal distributions are implemented as generative functions. For a random walk on the mean in your model, you can use, e.g.:

@gen function random_walk(trace, stdev)
    m ~ normal(trace[:m], stdev)
    # note that ~ is a newish syntactic sugar in Gen; this line is equivalent to m = @trace(normal(trace[:m], stdev), :m)
end

Then you can apply it using a different variant of mh that accepts the proposal instead of the selection, e.g.:

trace, = mh(trace, random_walk, (0.1,))
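Putting the pieces together, a rough sketch of a full random-walk run might look like the following. The function name random_walk_inference, the stdev keyword, the synthetic data, and the acceptance-rate bookkeeping are my own additions for illustration; the only Gen calls used are choicemap, generate, and the proposal variant of mh, which returns the new trace and whether the move was accepted.

```julia
using Gen, Random

@gen function model(n::Int)
    m = @trace(normal(0, 10.0), :m)
    for i in 1:n
        @trace(normal(m, 1.0), (:y, i))
    end
end

# Proposal: a Gaussian step centered on the current value of :m.
@gen function random_walk(trace, stdev)
    m ~ normal(trace[:m], stdev)
end

# Hypothetical helper, not part of Gen.
function random_walk_inference(y, num_iters; stdev=0.1)
    constraints = choicemap()
    for i in eachindex(y)
        constraints[(:y, i)] = y[i]
    end
    trace, = generate(model, (length(y),), constraints)
    ms = Float64[]
    num_accepted = 0
    for _ in 1:num_iters
        # The current trace is passed to the proposal automatically;
        # (stdev,) holds the remaining proposal arguments.
        trace, accepted = mh(trace, random_walk, (stdev,))
        num_accepted += accepted
        push!(ms, trace[:m])
    end
    return ms, num_accepted / num_iters
end

Random.seed!(1)
y = 3.0 .+ randn(50)   # illustrative synthetic data
ms, acceptance_rate = random_walk_inference(y, 2000; stdev=0.1)
println("acceptance rate: ", acceptance_rate)
```

Because the initial value of :m is drawn from the wide prior, a small stdev can take a few hundred steps to walk to the data, so it is worth discarding early samples and trying a couple of step sizes.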

Hope that helps!
