I am really interested in Gen.jl, but I find the documentation somewhat confusing: I could not find an example of the functionality I'm trying to use. Can someone help me with the first steps?
My first question: how do I get the code below to work?
My second question: how do I implement a custom proposal, for example a simple random-walk proposal? I tried a couple of approaches (for example, using the `update` function), but none of them led to sensible inferences on my simple example.
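For reference, a random-walk proposal in Gen can be written as its own generative function that reads the current `:m` from the trace and samples a new `:m` near it. That function and its extra arguments are then passed to `mh` in place of a selection. This is a minimal, self-contained sketch, assuming a recent Gen.jl release that provides `mh(trace, proposal, proposal_args)`; `random_walk`, `rw_inference`, the step size `width = 0.5`, and the synthetic data are illustrative choices, not taken from the post:

```julia
using Gen, Random

# Same model as in the post, without collecting y into a vector
@gen function model(n::Int)
    m = @trace(normal(0, 10.0), :m)
    for i in 1:n
        @trace(normal(m, 1.0), (:y, i))
    end
end

# Hypothetical random-walk proposal: Gen calls it as random_walk(trace, width),
# so it can read the current value of :m and propose a new one nearby.
@gen function random_walk(trace, width::Float64)
    @trace(normal(trace[:m], width), :m)
end

function rw_inference(y, num_iters; width = 0.5)
    n = length(y)
    # Constrain every observation (:y, i) to the data
    observations = choicemap()
    for i in 1:n
        observations[(:y, i)] = y[i]
    end
    trace, = generate(model, (n,), observations)
    ms = Float64[]
    for _ in 1:num_iters
        # Custom-proposal MH: mh returns (new_trace, accepted)
        trace, = mh(trace, random_walk, (width,))
        push!(ms, trace[:m])
    end
    ms
end

Random.seed!(1)
y = 3.0 .+ randn(50)   # synthetic data with true mean 3 (assumption for illustration)
ms = rw_inference(y, 2000)
```

Unlike `mh(trace, select(:m))`, which redraws `:m` from its normal(0, 10) prior, the random walk keeps every proposal close to the current value, so once the chain reaches the posterior a reasonable fraction of proposals should be accepted. The step size `width` is a tuning knob and may need adjusting for other data.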
The goal of the code is simply to estimate the mean parameter of a normal distribution with Metropolis-Hastings (MH).
In my runs, the chain gets stuck at its initial value, far from the mean of the data.
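Because the prior on `m` is normal and the noise standard deviation is known, the posterior of `m` is available in closed form, which gives a reference value to check the chain against. For a prior normal(μ₀, τ₀) and observations yᵢ ~ normal(m, σ), the posterior precision is 1/τ₀² + n/σ² and the posterior mean is (μ₀/τ₀² + Σyᵢ/σ²) divided by that precision. A minimal sketch in plain Julia; the function name `normal_mean_posterior` is hypothetical:

```julia
# Closed-form posterior for the mean m in the model
#   m ~ normal(mu0, tau0),  y[i] ~ normal(m, sigma)   (sigma known)
# Returns the posterior mean and posterior standard deviation of m.
function normal_mean_posterior(y; mu0 = 0.0, tau0 = 10.0, sigma = 1.0)
    precision = 1 / tau0^2 + length(y) / sigma^2
    post_mean = (mu0 / tau0^2 + sum(y) / sigma^2) / precision
    return post_mean, sqrt(1 / precision)
end

y = fill(3.0, 50)                 # toy data: 50 observations all equal to 3
mu, sd = normal_mean_posterior(y)
```

With n = 50 the posterior standard deviation is about 0.14, far narrower than the prior standard deviation of 10, so a proposal drawn from the prior will only rarely land where the posterior has mass.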
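```julia
using Gen, Plots, Random

# Model: unknown mean m with a wide normal prior, n observations y[i] ~ normal(m, 1)
@gen function model(n::Int)
    y = Vector{Float64}(undef, n)
    m = @trace(normal(0, 10.0), :m)
    for i in 1:n
        y[i] = @trace(normal(m, 1.0), (:y, i))
    end
    y
end

function do_inference(y, num_iters)
    n = length(y)  # use the data length instead of the global `n`
    # Constrain each observation (:y, i) to the data; :m is drawn from its prior
    observations = choicemap()
    for i in 1:n
        observations[(:y, i)] = y[i]
    end
    trace, = generate(model, (n,), observations)
    ms = Float64[]
    for i = 1:num_iters
        # Selection-based MH: proposes a fresh :m from its prior normal(0, 10)
        trace, = mh(trace, select(:m))
        @show trace[:m], get_score(trace)
        push!(ms, trace[:m])
    end
    ms
end

Random.seed!(1)
n = 50
# Simulate a synthetic dataset from the model itself
(trace, _) = Gen.generate(model, (n,))
m = trace[:m]  # true mean used to generate y
y = [trace[(:y, i)] for i in 1:n];
chn = do_inference(y, 1000)  # do_inference returns only the vector of samples
plot(y)
hline!([m])
plot!(chn)
```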
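Since the post mentions trying `update`, here is a sketch of the same random-walk step written by hand with it, assuming Gen's `update(trace, args, argdiffs, constraints)` form, which returns `(new_trace, weight, retdiff, discard)`. The names `rw_step` and `rw_inference_update` and the synthetic data are hypothetical, for illustration only:

```julia
using Gen, Random

@gen function model(n::Int)
    m = @trace(normal(0, 10.0), :m)
    for i in 1:n
        @trace(normal(m, 1.0), (:y, i))
    end
end

# One hand-written random-walk MH step (hypothetical helper).
# The proposal m' ~ normal(m, width) is symmetric, so the acceptance ratio is
# p(m', y) / p(m, y). Because update samples no new random choices here,
# its `weight` is exactly the log of that joint-density ratio.
function rw_step(trace, width)
    new_m = trace[:m] + width * randn()
    new_trace, weight, _, _ = update(trace, get_args(trace), (NoChange(),),
                                     choicemap((:m, new_m)))
    return log(rand()) < weight ? new_trace : trace
end

function rw_inference_update(y, num_iters; width = 0.5)
    # Constrain every observation (:y, i) to the data
    observations = choicemap()
    for (i, yi) in enumerate(y)
        observations[(:y, i)] = yi
    end
    trace, = generate(model, (length(y),), observations)
    ms = Float64[]
    for _ in 1:num_iters
        trace = rw_step(trace, width)
        push!(ms, trace[:m])
    end
    ms
end

Random.seed!(2)
y = -1.0 .+ randn(50)   # synthetic data with true mean -1 (assumption for illustration)
ms = rw_inference_update(y, 2000)
```

This does by hand what `mh(trace, proposal, args)` automates; for an asymmetric proposal the log proposal densities would also have to enter the acceptance ratio.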