Calculate joint probability from Turing model

Given a model defined in Turing, I would like to calculate the joint probability function over all of its variables. Say I have the model

```julia
using Turing

@model function model(y)
    mu ~ Normal(0, 1)
    y ~ Normal(mu, 1)
end

y_obs = 1
mu = 1
```
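For reference, the joint log density of this toy model can be written out by hand, which gives a target value to check any Turing-based function against. The sketch below uses plain Julia with no packages; the helper names `normal_logpdf` and `log_joint` are hypothetical and not part of Turing.

```julia
# Hand-written log density of a univariate normal, so no packages are needed.
normal_logpdf(x, m, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - m) / s)^2

# Joint log density of the toy model:
# log p(mu, y) = log N(mu | 0, 1) + log N(y | mu, 1)
log_joint(mu, y) = normal_logpdf(mu, 0, 1) + normal_logpdf(y, mu, 1)

lp = log_joint(1, 1)   # mu = 1, y_obs = 1
println(lp)            # ≈ -2.3379
println(exp(lp))       # the joint density itself
```

With mu = y_obs = 1 the second term has no quadratic part, so the result is simply -log(2π) - 0.5.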

With the prob macro I can write

```julia
prob"mu = mu, y = y_obs | model = model"
```

and it gives exactly what I want. However, I would like to avoid passing the value of y_obs explicitly and instead write

```julia
m = model(y_obs)
prob"mu = mu | model = m"
```

Is that somehow possible?

My end goal is a function that takes only an instantiated model as input and returns its joint probability function. In other words, I want a function get_joint_prob such that:

```julia
m = model(y_obs)
joint_prob = get_joint_prob(m)
joint_prob((mu=mu,)) == prob"mu = mu, y = y_obs | model = model"
```

(joint_prob takes a named tuple so that it can handle multiple variables if necessary.)
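To make the intended interface concrete, here is a hand-written stand-in for get_joint_prob that works only for the toy model. It is a sketch, not a Turing API: the name `get_joint_prob_toy` is hypothetical, and the log density is coded by hand instead of being derived from the model. It shows the closure pattern: the observed data is captured once, and the returned function takes the latent values as a named tuple.

```julia
# Hand-written normal log density (no packages needed).
normal_logpdf(x, m, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - m) / s)^2

# Hypothetical stand-in for get_joint_prob, specific to the toy model:
# it captures y_obs and returns a function of the latent variables.
function get_joint_prob_toy(y_obs)
    return function joint_density(latent::NamedTuple)
        mu = latent.mu  # look up the latent variable by name
        lp = normal_logpdf(mu, 0, 1) + normal_logpdf(y_obs, mu, 1)
        return exp(lp)  # joint density p(mu, y_obs)
    end
end

joint_prob = get_joint_prob_toy(1)
println(joint_prob((mu = 1,)))
```

A generic, Turing-based version has to do the same job without hand-coding the density: take the values from the named tuple, run the model with them, and read off the accumulated log probability.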

I have already read through the Turing documentation and looked at the implementation of the prob macro and of VarInfo, but I didn't understand them well enough to get something working. Any pointers would be really helpful!

After looking at the MH sampler implementation, I came up with the following solution:

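```julia
using Turing

# Returns a function that evaluates the log joint density of `model`
# (an instantiated model such as `model(y_obs)`) at a given value of `mu`.
function make_log_joint_density(model)
    return function joint_density(xval)
        vi = Turing.VarInfo(model)    # VarInfo populated by one run of the model
        vi[@varname(mu)] = [xval]     # overwrite the stored value of `mu`
        model(vi)                     # re-run the model with that value
        return Turing.getlogp(vi)     # log joint density log p(mu, y_obs)
    end
end
```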

This works for my small toy model above. Using the function set_namedtuple!(vi::VarInfo, nt::NamedTuple), I can generalise this to any set of variables:

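```julia
using Turing

# `named_tuple` maps variable names to values, e.g. (mu = 1.0,).
function make_log_joint_density(model)
    return function joint_density(named_tuple)
        vi = Turing.VarInfo(model)
        # copy the given values into `vi`; set_namedtuple! may need to be
        # qualified with the module that defines it
        set_namedtuple!(vi, named_tuple)
        model(vi)                     # re-run the model with those values
        return Turing.getlogp(vi)     # log joint density
    end
end
```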

Does this sound like a reasonable approach, or is there a better way to do this? Is there any documentation on the VarInfo type and how to interact with it?

I found the documentation on the Turing compiler design, and it answered most of my questions.


I would look at the implementation of the macro. These are the functions that get called in your case, and they can be customized for your use case:
https://github.com/TuringLang/DynamicPPL.jl/blob/master/src/prob_macro.jl#L24
https://github.com/TuringLang/DynamicPPL.jl/blob/master/src/prob_macro.jl#L118
