Hello everyone.
I come from a general Machine Learning and Neural Networks background, and over the last year I learned about probabilistic programming and Bayesian inference. Naturally I wanted to see what was available in the Julia ecosystem, and I was not disappointed.
But then a thought crossed my mind, and I'd like to hear opinions from more experienced people.
Why do we actually need a Probabilistic Programming Language (PPL)? In Julia it seems like mere syntactic sugar over specifying a distribution and a sampler directly with the tools from Distributions.jl.
The advantage I envision for such a direct approach is easier composition of packages, samplers, etc., in line with Julia packages at large. Perhaps there is also less overhead from the PPL, e.g. less compilation time?
What would be the disadvantages? Just inconvenience, or is there something more? Also, what do the different PPLs offer relative to one another?
Keen to read your thoughts,
Lior
========
To clarify, let's take an example from the Turing tutorial, Coin Flipping. The Turing model is:
```julia
using Turing

@model function coinflip(; N::Int)
    # Our prior belief about the probability of heads in a coin toss.
    p ~ Beta(1, 1)
    # Heads or tails of a coin are drawn from `N` independent and identically
    # distributed Bernoulli distributions with success rate `p`.
    y ~ filldist(Bernoulli(p), N)
    return y
end;
```
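Written out, the model above defines the joint log density below (using $\pi$ for densities to avoid a clash with the parameter $p$). Conditioning on the observed flips $y_i \in \{0, 1\}$ turns it into the unnormalized log posterior of $p$, which is what any hand-written alternative has to reproduce. With $h = \sum_i y_i$ heads it is the log of an unnormalized $\mathrm{Beta}(1 + h, 1 + N - h)$ density:

$$
\log \pi(p, y_{1:N}) = \underbrace{\log \mathrm{Beta}(p \mid 1, 1)}_{=\,0 \text{ for } p \in (0, 1)} + \sum_{i=1}^{N} \bigl[\, y_i \log p + (1 - y_i) \log(1 - p) \,\bigr] = h \log p + (N - h) \log(1 - p).
$$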
Now, conditioning on observed flips `y` in Turing is done with

```julia
coinflip(y::AbstractVector{<:Real}) = coinflip(; N=length(y)) | (; y)
```
We then sample with, e.g.:

```julia
model = coinflip(data)  # `data`: the observed flips, e.g. a Vector{Bool}
chain = sample(model, NUTS(), 2_000, progress=false);
```
An alternative that reuses Turing's inference code would be the following (please excuse any workarounds around `product_distribution`):
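```julia
using Distributions

prior = Beta(1, 1)                 # prior belief about p = P(heads)
likelihood(p) = Bernoulli(p)       # a single coin flip
# Joint distribution of all flips (plays the role of Turing's `filldist`).
coinflip_joint(p) = product_distribution(fill(likelihood(p), length(data)))
# Unnormalized log posterior of p given the observed (global) `data`.
coinflip_logpdf(p) = logpdf(prior, p) + logpdf(coinflip_joint(p), data)
# + const
```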
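Because the Beta(1, 1) prior is conjugate to the Bernoulli likelihood, the exact posterior is Beta(1 + h, 1 + N − h), with h the number of heads. That gives a cheap sanity check for a hand-written log density: it should differ from the exact posterior log density only by a constant (the log evidence). A minimal, self-contained sketch, assuming `data` is a `Vector{Bool}`; the simulated data is purely illustrative, and this `coinflip_logpdf(p, data)` variant takes the data as an argument instead of reading a global:

```julia
using Distributions, Random

Random.seed!(1)
data = rand(Bernoulli(0.5), 100)   # illustrative simulated flips (Vector{Bool})

# Unnormalized log posterior: log prior + log likelihood of all flips.
coinflip_logpdf(p, data) = logpdf(Beta(1, 1), p) + loglikelihood(Bernoulli(p), data)

# Exact conjugate posterior: Beta(1 + heads, 1 + tails).
heads = count(data)
posterior = Beta(1 + heads, 1 + length(data) - heads)

# log joint - log posterior = log p(data), so this should be the same for every p.
diffs = [coinflip_logpdf(p, data) - logpdf(posterior, p) for p in (0.1, 0.3, 0.5, 0.9)]
```

If the entries of `diffs` disagree, the hand-written density is not the posterior up to a constant.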
We could then reuse Turing's inference libraries by wrapping this density as a LogDensityProblem (though I didn't completely understand the docs):
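```julia
# Pseudocode: `CoinFlipProblem` would implement the LogDensityProblems.jl
# interface (`logdensity`, `dimension`, `capabilities`)?
chain = sample(CoinFlipProblem, NUTS(), 2_000, progress=false);
```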
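For reference, LogDensityProblems.jl is an interface rather than an abstract type to subtype: a problem is any type for which `logdensity`, `dimension` and `capabilities` are defined. Below is a minimal sketch for the coin flip under stated assumptions: `CoinFlipProblem` is a hypothetical name, and since HMC/NUTS samples on unconstrained ℝ^d, it reparameterizes p = logistic(θ) and adds the log-Jacobian of that transform, a step Turing performs automatically for constrained parameters.

```julia
using Distributions, LogDensityProblems

# Hypothetical problem type holding the observed flips.
struct CoinFlipProblem{T<:AbstractVector{Bool}}
    data::T
end

# Log density in the unconstrained parameter θ = logit(p).
function LogDensityProblems.logdensity(prob::CoinFlipProblem, θ::AbstractVector)
    p = 1 / (1 + exp(-θ[1]))        # logistic: ℝ → (0, 1)
    logjac = log(p) + log1p(-p)     # log |dp/dθ| = log p + log(1 - p)
    return logpdf(Beta(1, 1), p) + loglikelihood(Bernoulli(p), prob.data) + logjac
end

LogDensityProblems.dimension(::CoinFlipProblem) = 1

# Order 0: only the log density is provided; gradients must come from AD.
LogDensityProblems.capabilities(::Type{<:CoinFlipProblem}) =
    LogDensityProblems.LogDensityOrder{0}()
```

A gradient-based sampler would additionally need gradients, typically by wrapping the problem with an AD backend (e.g. `ADgradient(:ForwardDiff, prob)` from LogDensityProblemsAD.jl) before passing it to a package such as AdvancedHMC.jl; the exact `sample` call varies between versions, so the respective docs are the authority there. Samples of θ must be mapped back through the logistic function to obtain p.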