The simplest solution might be something like the one proposed in "Sampling from a KDE object" (post #4, by sethaxen). A KDE is just a mixture model, so if you keep a copy of the data and the bandwidth, you can convert it to a `Distributions.MixtureModel`.
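A minimal sketch of that conversion, assuming the default Gaussian kernel, for which the bandwidth acts as the standard deviation of each component. The data and the bandwidth value `bw = 2.0` are made up for illustration; the same bandwidth is passed to `kde` so that the two objects describe the same density.

```julia
using Distributions, KernelDensity, Random

Random.seed!(1)
data = Float64.(rand(Poisson(100), 100))

# Hypothetical bandwidth, chosen by hand and shared by both objects.
bw = 2.0
k = kde(data; bandwidth=bw)

# One Normal component per data point, all equally weighted.
mix = MixtureModel(Normal.(data, bw))

# The two densities should roughly agree (the KDE is evaluated on a grid).
pdf(k, 100.0), pdf(mix, 100.0)

logpdf(mix, 100.0)  # touches every component, so cost grows with length(data)
rand(mix)           # sampling is already provided by Distributions
```

Because `mix` is an ordinary `Distributions` object, it should be usable directly on the right-hand side of `~` in a Turing model.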
This may not be ideal for you, since `logpdf` evaluation will scale with the size of the data that the KDE was fit to. An alternative implementation could instead interpolate between the grid points at which the KDE is evaluated. That might look something like this:
```julia
using KernelDensity, Distributions, Interpolations, Random, Turing

# Wrap an interpolated KDE so that it can be used as a univariate distribution.
struct InterpKDEDistribution{T<:Real,K<:KernelDensity.InterpKDE} <: ContinuousUnivariateDistribution
    kde::K
end

function InterpKDEDistribution(k::KernelDensity.InterpKDE)
    T = eltype(k.kde.x)
    return InterpKDEDistribution{T,typeof(k)}(k)
end

function InterpKDEDistribution(k::KernelDensity.UnivariateKDE)
    return InterpKDEDistribution(KernelDensity.InterpKDE(k))
end

# The support is the range of the grid on which the KDE was evaluated.
function Distributions.minimum(d::InterpKDEDistribution)
    return first(only(Interpolations.bounds(d.kde.itp.itp)))
end

function Distributions.maximum(d::InterpKDEDistribution)
    return last(only(Interpolations.bounds(d.kde.itp.itp)))
end

function Distributions.pdf(d::InterpKDEDistribution, x::Real)
    return pdf(d.kde, x)
end

function Distributions.logpdf(d::InterpKDEDistribution, x::Real)
    return log(pdf(d, x))
end

# need to at least have a very rough implementation of this
# much better/more efficient implementation is possible,
# but this will only be called once when sampling starts
function Random.rand(rng::Random.AbstractRNG, d::InterpKDEDistribution)
    (; kde) = d
    knots = collect(Interpolations.knots(kde.itp.itp))
    cdf = cumsum(pdf.(Ref(kde), knots))
    cdf ./= last(cdf)  # normalize so that the discrete CDF ends at 1
    u = rand(rng)
    # inverse-CDF sampling on the grid: first knot whose CDF value reaches u
    return knots[searchsortedfirst(cdf, u)]
end

x = rand(Poisson(100), 100)
k = KernelDensity.kde(x)
d = InterpKDEDistribution(k)

@model function foo()
    x ~ d
end

model = foo()
chain = sample(model, NUTS(), 1_000)
```
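The comments on `rand` mention that a more efficient implementation is possible. One option, sketched here under the assumption that sampling on the KDE grid is accurate enough, is to build the normalized CDF once and reuse it, so that each draw is only a binary search. The helper names `grid_cdf` and `sample_grid` are hypothetical and not part of KernelDensity.

```julia
using KernelDensity, Random

# Hypothetical helper: normalized cumulative sum of the gridded density.
function grid_cdf(k::KernelDensity.UnivariateKDE)
    cdf = cumsum(k.density)
    return cdf ./ last(cdf)
end

# Hypothetical helper: inverse-CDF draw that returns a grid point of the KDE.
function sample_grid(rng::Random.AbstractRNG, k::KernelDensity.UnivariateKDE, cdf::AbstractVector)
    u = rand(rng)
    i = searchsortedfirst(cdf, u)  # first index with cdf[i] >= u
    return k.x[min(i, length(cdf))]
end

rng = Random.MersenneTwister(1)
k = kde(randn(rng, 500))
cdf = grid_cdf(k)  # computed once, reused for every draw
draws = [sample_grid(rng, k, cdf) for _ in 1:1_000]
```

Draws are restricted to grid points, which is coarse but usually fine for generating initial values; interpolating within a grid cell would give continuous draws at the cost of a little more code.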
Note that no custom bijector is necessary: the support of a univariate distribution is lower-bounded, upper-bounded, or both, and bijectors are already defined for these cases. In other words, defining `minimum` and `maximum` is sufficient here.
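As a rough illustration of what those built-in transformations do, the sketch below maps a point from a bounded support to the real line and back. `Uniform(2, 5)` is only a stand-in for a distribution with finite `minimum` and `maximum`; that a custom distribution defining the two methods is treated the same way is the claim made above, not something this snippet verifies.

```julia
using Bijectors, Distributions

# Stand-in bounded distribution with support (2, 5).
d = Uniform(2.0, 5.0)

y = Bijectors.link(d, 3.0)     # constrained value -> unconstrained real
x = Bijectors.invlink(d, y)    # unconstrained real -> back inside (2, 5)
```

Samplers such as NUTS work in the unconstrained space, which is why the bounds reported by the distribution matter.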