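I’m working on a “Rational Speech Acts” (RSA) model. Essentially, RSA is a recursive model in which agents reason about each other: a listener infers a distribution over states of the world given an utterance, and a speaker infers a distribution over utterances by simulating a listener.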
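In symbols, this is a minimal sketch of the recursion. It assumes the simplest variant: speaker rationality fixed at 1, an utterance prior $P(u)$ in place of an explicit cost term, $s = (\text{location}, \text{margin})$, and independent priors $P(s) = P(\text{location})\,P(\text{margin})$:

$$
\begin{aligned}
L_0(s \mid u) &\propto P(s)\,\mathbb{1}[\,s \text{ is compatible with } u\,] \\
S_1(u \mid s) &\propto P(u)\,L_0(s \mid u) \\
L_1(s \mid u) &\propto P(s)\,S_1(u \mid s)
\end{aligned}
$$

With an ε-tolerance on location, the $L_0(s \mid u)$ factor in the speaker line becomes the $L_0$ probability mass on states with the same margin and a location within ε of the true one.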
So, in this case I have the following setup:
```julia
function literalListener(utterance)
    location ~ locationprior()
    margin ~ marginprior()
    # Acceptance condition:
    #   if `location` is inside the region named by `utterance` (or within
    #   `margin` of it), leave the log-probability unchanged; otherwise set
    #   it to -Inf, i.e. reject the sample.
    # Make sure `location` and `margin` both end up on the trace.
end
```
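For comparison, here is what exact enumeration of the literal listener amounts to. This is a minimal sketch on a hypothetical toy domain (1-D locations, intervals as regions, uniform priors; `LOCATIONS`, `MARGINS`, `REGIONS`, and `accepts` are illustrative names, not part of the real model). A sample whose log-probability would be set to -Inf simply gets weight 0, so the posterior is prior times indicator, normalised.

```python
# Hypothetical toy domain, NOT the real model: 1-D locations, three margins,
# and utterances that each name a closed interval on the line.
LOCATIONS = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
MARGINS = [0.0, 0.5, 1.0]
REGIONS = {"left": (0.0, 1.0), "middle": (2.0, 3.0), "right": (4.0, 5.0)}


def location_prior(loc):
    """Assumed uniform prior over locations."""
    return 1.0 / len(LOCATIONS)


def margin_prior(margin):
    """Assumed uniform prior over margins."""
    return 1.0 / len(MARGINS)


def accepts(loc, margin, utterance):
    """Acceptance condition: loc lies in the named region, widened by margin."""
    lo, hi = REGIONS[utterance]
    return lo - margin <= loc <= hi + margin


def literal_listener(utterance):
    """Exact L0 posterior over (location, margin) pairs.

    Rejected pairs (log-probability -Inf) are simply left out, so the
    posterior is prior * indicator followed by normalisation."""
    weights = {
        (loc, m): location_prior(loc) * margin_prior(m)
        for loc in LOCATIONS
        for m in MARGINS
        if accepts(loc, m, utterance)
    }
    total = sum(weights.values())
    return {state: w / total for state, w in weights.items()}
```

Here `literal_listener("left")` spreads its mass evenly over 7 states; `(2.0, 1.0)` is among them only because the widest margin stretches the region to reach location 2.0.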
Then we have something like:
```julia
function pragmaticSpeaker(location, margin)
    utterance ~ utteranceprior()
    L0 ~ literalListener(utterance)
    # Acceptance conditions (accept only if both hold):
    #   L0.margin == margin
    #   dist(L0.location, location) < epsilon
    # Make sure `utterance` ends up on the trace.
end
```
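The speaker’s rejection step can be replaced by its acceptance probability: for each utterance, the L0 mass that passes both tests. A minimal sketch, assuming the L0 posteriors have already been tabulated (the `L0` and `UTTERANCE_PRIOR` tables below are hypothetical toy values):

```python
# Hypothetical precomputed L0 tables: utterance -> {(location, margin): probability}.
# In practice these would come from enumerating literalListener.
L0 = {
    "near": {(0.0, 0.5): 0.5, (1.0, 0.5): 0.5},
    "far": {(1.0, 0.5): 0.25, (5.0, 0.5): 0.75},
}
UTTERANCE_PRIOR = {"near": 0.5, "far": 0.5}  # assumed uniform


def pragmatic_speaker(location, margin, epsilon):
    """Exact S1(utterance | location, margin).

    Instead of sampling from L0 and rejecting, weight each utterance by the
    probability that the rejection test passes: the L0 mass with the same
    margin and a location within epsilon of the true one."""
    scores = {}
    for u, prior_u in UTTERANCE_PRIOR.items():
        p_accept = sum(
            p
            for (loc, m), p in L0[u].items()
            if m == margin and abs(loc - location) < epsilon
        )
        scores[u] = prior_u * p_accept
    total = sum(scores.values())
    if total == 0.0:
        raise ValueError("no utterance is compatible with this state")
    return {u: s / total for u, s in scores.items()}
```

For example, `pragmatic_speaker(1.0, 0.5, 0.5)` gives "near" probability 2/3 and "far" 1/3, because "near" puts more L0 mass at location 1.0.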
And finally:
```julia
function pragmaticListener(utterance)
    location ~ locationprior()
    margin ~ marginprior()
    S1 ~ pragmaticSpeaker(location, margin)
    # Acceptance condition:
    #   accept if S1.utterance == utterance
    # Make sure `location` and `margin` end up on the trace.
end
```
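Likewise, the pragmatic listener’s condition `S1.utterance == utterance` passes with probability S1(utterance | location, margin), which becomes the likelihood weight. A minimal sketch over a hypothetical precomputed `S1` table with assumed uniform priors:

```python
LOCATIONS = [0.0, 1.0, 2.0]
MARGINS = [0.5, 1.0]

# Hypothetical precomputed S1 table: (location, margin) -> {utterance: probability}.
S1 = {
    (0.0, 0.5): {"near": 1.0, "far": 0.0},
    (0.0, 1.0): {"near": 0.8, "far": 0.2},
    (1.0, 0.5): {"near": 0.5, "far": 0.5},
    (1.0, 1.0): {"near": 0.5, "far": 0.5},
    (2.0, 0.5): {"near": 0.0, "far": 1.0},
    (2.0, 1.0): {"near": 0.2, "far": 0.8},
}


def prior(location, margin):
    """Assumed independent uniform priors over location and margin."""
    return (1.0 / len(LOCATIONS)) * (1.0 / len(MARGINS))


def pragmatic_listener(utterance):
    """Exact L1 posterior over (location, margin) given an utterance.

    The rejection test 'S1.utterance == utterance' passes with probability
    S1(utterance | location, margin), so that is the likelihood weight."""
    weights = {}
    for loc in LOCATIONS:
        for m in MARGINS:
            w = prior(loc, m) * S1[(loc, m)][utterance]
            if w > 0.0:  # zero-weight states would always be rejected
                weights[(loc, m)] = w
    total = sum(weights.values())
    return {state: w / total for state, w in weights.items()}
```

Because the priors here are uniform, they cancel, and `pragmatic_listener("near")` is just the "near" column of `S1` renormalised.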
Given that the locations, margins, and utterances are all finite (of sizes 150,000, 3, and 11, respectively), I believe this could be enumerated exactly, and in fairly reasonable time. (I think the bottlenecks I’m running into with this model in WebPPL stem from an inability to vectorize operations like checking `dist(L0.location, location) < epsilon`. This holds even with less stringent approaches, such as drawing 5K rejection samples instead of enumerating everything.)
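For scale: 150,000 × 3 = 450,000 (location, margin) states per utterance, and × 11 = 4,950,000 entries for a full L0 table, which is small enough to hold in memory. The expensive part is the ε-check done pairwise: 150,000 × 150,000 = 2.25 × 10^10 distance tests per (utterance, margin) pair. A minimal sketch of one way around that, assuming 1-D locations (a hypothetical simplification; for 2-D locations a grid with cell size ε or a k-d tree plays the same role): sort once, take prefix sums, and answer each window query with two binary searches.

```python
import bisect
from itertools import accumulate


def window_mass(locations, probs, queries, epsilon):
    """For each q in queries, return sum(probs[i] for i with |locations[i] - q| < epsilon).

    Assumes 1-D locations (hypothetical). Sorting once plus prefix sums makes
    each query O(log n) instead of the O(n) scan of a naive pairwise check."""
    if epsilon <= 0:
        raise ValueError("epsilon must be positive")
    order = sorted(range(len(locations)), key=locations.__getitem__)
    xs = [locations[i] for i in order]
    prefix = [0.0] + list(accumulate(probs[i] for i in order))
    result = []
    for q in queries:
        lo = bisect.bisect_right(xs, q - epsilon)  # first x strictly above q - epsilon
        hi = bisect.bisect_left(xs, q + epsilon)   # first x at or above q + epsilon
        result.append(prefix[hi] - prefix[lo])
    return result
```

For a fixed utterance and margin, `probs` would be the corresponding slice of L0 and `queries` the true locations; each returned value is then the speaker’s acceptance probability for that location.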
So, this leads me to my questions:
- Does Turing.jl support this kind of process? I had a stripped-down version of RSA working in Gen (with fewer than 50 utterances + states), but it required bending Gen in ways it doesn’t like: because of the acceptance conditions, I had to mix “inference machinery” into the model itself.
- Does Turing.jl support exact enumeration (or something similar) over these kinds of discrete domains?
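To pin down what “exact enumeration” means here, this is a generic sketch in plain Python (not a Turing.jl or Gen API) over a product of finite domains. A rejected assignment gets log-weight -inf, mirroring the acceptance conditions above, and normalising in log space avoids underflow when many weights are tiny. The cost is the product of the domain sizes, so nested levels should be tabulated once rather than re-enumerated inside every outer call.

```python
import itertools
import math


def enumerate_posterior(domains, log_weight):
    """Exact posterior over the product of finite domains.

    domains: dict name -> iterable of values.
    log_weight(assignment): log prior + log likelihood for a dict assignment,
    or -math.inf to reject it. Returns {tuple_of_values: probability}, with
    tuple entries in the order of the domain names."""
    names = list(domains)
    support = []
    for values in itertools.product(*(domains[n] for n in names)):
        lw = log_weight(dict(zip(names, values)))
        if lw != -math.inf:
            support.append((values, lw))
    if not support:
        raise ValueError("every assignment was rejected")
    # log-sum-exp normalisation
    top = max(lw for _, lw in support)
    log_z = top + math.log(sum(math.exp(lw - top) for _, lw in support))
    return {values: math.exp(lw - log_z) for values, lw in support}
```

For instance, two fair dice conditioned on summing to 7 give probability 1/6 to each of the six surviving pairs.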
Ideally, I’d be able to extract the posteriors from each of `literalListener`, `pragmaticSpeaker`, and `pragmaticListener`.
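If every level is enumerated into a table, extracting a posterior from any level amounts to reading that table, or marginalising it onto one variable. A minimal sketch, assuming a joint posterior stored as a dict keyed by `(location, margin)` tuples (a hypothetical representation):

```python
from collections import defaultdict


def marginal(joint, index):
    """Marginalise a joint posterior {tuple: prob} onto coordinate `index`."""
    out = defaultdict(float)
    for key, p in joint.items():
        out[key[index]] += p
    return dict(out)
```

With `index=0` this gives P(location | utterance) and with `index=1` it gives P(margin | utterance), for either listener level.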