Turing support for exact inference (over discrete supports)

I’m working on a “Rational Speech Acts” (RSA) model. Essentially, RSA is a recursive process in which you query distributions over states of the world and over utterances.

So, in this case I have the following setup:

function literalListener(utterance)
    location ~ locationprior()
    margin ~ marginprior()
    
    # acceptance conditions:
    #   if the location is inside the region (or within the margin of it),
    #   don't adjust the logprob; otherwise set it to -Inf (hard rejection)

    # ensure that `location` and `margin` are stacked onto the trace
end

then we have something like…

function pragmaticSpeaker(location, margin)
    utterance ~ utteranceprior()
    L0 ~ literalListener(utterance)
   
    # acceptance conditions:
    #   if `L0.margin == margin`, accept
    #   if `dist(L0.location, location) < epsilon`, accept

    # ensure that `utterance` is stacked onto the trace
end

and finally…

function pragmaticListener(utterance)
    location ~ locationprior()
    margin ~ marginprior()
    
    S1 ~ pragmaticSpeaker(location, margin)

    # acceptance conditions:
    #   if S1.utterance == utterance, accept
   
    # ensure that `location` and `margin` are stacked onto the trace
end

Given that the locations, margins, and utterances are all finite (supports of size 150,000, 3, and 11, respectively), I believe this is something that could be enumerated, and in fairly reasonable time. (I think the bottlenecks I’m running into with this model in WebPPL stem from an inability to vectorize checks like dist(L0.location, location) < epsilon. That holds even with less stringent requirements, e.g. 5K rejection samples instead of full enumeration.)
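
For what it’s worth, the check I mean vectorizes trivially in plain Julia, something like this (a toy sketch; the sizes and names are made up):

# Toy sketch: broadcast the distance check over the whole location support.
locations = range(0.0, 1.0; length = 150_000)   # placeholder support
location  = 0.42                                # one candidate location
epsilon   = 1e-3
accepted  = abs.(locations .- location) .< epsilon   # BitVector over the support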

So, this leads me to my questions:

  1. Does Turing.jl support this kind of process? (I had a stripped-down version of RSA, with fewer than 50 utterances + states, working in Gen, but it required bending Gen in ways it doesn’t like: I had to mix inference machinery into the modeling code because of the acceptance conditions.)
  2. Does Turing.jl support exact enumeration (or similar) over these kinds of [discrete] domains?

Ideally, I’d be able to extract the posteriors from each of literalListener, pragmaticSpeaker, and pragmaticListener.


I was also hoping to find a way to do exact inference over discrete supports in Turing.jl. It looks like it might be possible to implement something like this Pyro poutine with Turing contexts.

I just started playing with Turing.jl a couple of days ago, so I could be totally wrong.

But, since it’s been a couple years since this was asked, I wanted to bump the thread in case someone came up with a great solution, or had a suggestion about where to start.

EDIT: Whoops! Posted what was supposed to be a reply to another thread here :grimacing: Give me a sec and I’ll write a response to the actual question.

EDIT 2:

The issue with Turing.jl is that we don’t have access to the DAG, so automating something like this is unfortunately not doable (at the moment).

So the most straightforward way to achieve this through manual labour is to define a Distribution that takes in your continuous parameters and performs the marginalization in the logpdf computation.
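
A minimal sketch of what I mean (hedged: `Marginalized`, `logw`, and `components` are hypothetical names, and the log-weights are assumed normalized):

using Distributions, Random
using StatsFuns: logsumexp

# Sketch: a distribution whose logpdf marginalizes out a discrete latent z,
# so the sampler only ever sees the continuous observation.
struct Marginalized{D<:ContinuousUnivariateDistribution} <: ContinuousUnivariateDistribution
    logw::Vector{Float64}     # log p(z), assumed normalized
    components::Vector{D}     # p(x | z) for each value of z
end

function Distributions.logpdf(d::Marginalized, x::Real)
    # log p(x) = logsumexp over z of [log p(z) + log p(x | z)]
    return logsumexp(d.logw .+ logpdf.(d.components, x))
end

function Base.rand(rng::Random.AbstractRNG, d::Marginalized)
    z = rand(rng, Categorical(exp.(d.logw)))   # draw the latent...
    return rand(rng, d.components[z])          # ...then the observation
end

(For a plain mixture, Distributions.MixtureModel already does this marginalization for you; a custom type is only worth it when the acceptance conditions make the structure more exotic than a mixture.)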

Alternatively, you can just perform the enumeration right in the model and use @addlogprob! to accumulate the marginal log-probabilities.
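
For example (a hedged sketch, not the RSA model itself: one continuous parameter p plus a three-valued discrete latent that gets enumerated away; all names are illustrative):

using Turing
using StatsFuns: logsumexp

# Sketch: enumerate the discrete latent z inline and add its marginal
# log-likelihood with @addlogprob!; z never enters the trace, so
# gradient-based samplers like NUTS still work on p.
@model function enumerated_model(x)
    p ~ Beta(2, 2)                        # continuous parameter to infer
    means = [-2.0, 0.0, 2.0]              # support of the discrete latent z
    logw  = log.([p / 2, 1 - p, p / 2])   # log p(z | p); weights sum to one
    for xi in x
        @addlogprob! logsumexp(logw .+ logpdf.(Normal.(means, 1.0), xi))
    end
end

chain = sample(enumerated_model(randn(100)), NUTS(), 500)

Because the discrete variable is marginalized out analytically, you also sidestep the particle-Gibbs-style machinery you would otherwise need for discrete latents.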