The simplest solution might be something like that proposed in Sampling from a KDE object - #4 by sethaxen. A KDE is just a mixture model, so if you keep a copy of the data and the bandwidth, you can convert it to a `Distributions.MixtureModel`.
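A minimal sketch of that conversion follows. It assumes KernelDensity's default Gaussian kernel, whose standard deviation equals the bandwidth, and it fixes the bandwidth explicitly so the mixture and the KDE agree. The value `bw = 2.0` is an arbitrary illustrative choice, not a recommendation. Each data point becomes one equally weighted `Normal` component:

```julia
using Distributions, KernelDensity, Random

rng = MersenneTwister(1)
data = Float64.(rand(rng, Poisson(100), 100))

# Fix the bandwidth explicitly so the mixture and the KDE use the same one
bw = 2.0  # arbitrary value chosen for illustration
k = kde(data; bandwidth = bw)  # gridded KDE with the default Normal kernel

# The same estimate as an exact mixture: one equally weighted Normal per data point
mix = MixtureModel(Normal.(data, bw))

# Both should agree closely; the KDE is a gridded approximation of the mixture
pdf(mix, 100.0), pdf(k, 100.0)
```

`mix` can then appear on the right-hand side of `~` like any other distribution, but every `logpdf(mix, x)` call sums over all `length(data)` components.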
This may not be ideal for you, since `logpdf` evaluation will scale with the size of the data that the KDE was fit to. Alternatively, you could use an interpolation scheme between the grid points at which the KDE is evaluated. That might look something like this:
```julia
using KernelDensity, Distributions, Interpolations, Random, Turing

# Wraps an interpolated KDE so it can be used as a Distributions.jl distribution
struct InterpKDEDistribution{T<:Real,K<:KernelDensity.InterpKDE} <: ContinuousUnivariateDistribution
    kde::K
end
function InterpKDEDistribution(k::KernelDensity.InterpKDE)
    T = eltype(k.kde.x)
    return InterpKDEDistribution{T,typeof(k)}(k)
end
function InterpKDEDistribution(k::KernelDensity.UnivariateKDE)
    return InterpKDEDistribution(KernelDensity.InterpKDE(k))
end

# The support is the range of the grid the KDE was evaluated on
function Distributions.minimum(d::InterpKDEDistribution)
    return first(only(Interpolations.bounds(d.kde.itp.itp)))
end
function Distributions.maximum(d::InterpKDEDistribution)
    return last(only(Interpolations.bounds(d.kde.itp.itp)))
end

function Distributions.pdf(d::InterpKDEDistribution, x::Real)
    return pdf(d.kde, x)
end
function Distributions.logpdf(d::InterpKDEDistribution, x::Real)
    return log(pdf(d, x))
end

# need to at least have a very rough implementation of this
# much better/more efficient implementation is possible,
# but this will only be called once when sampling starts
function Random.rand(rng::Random.AbstractRNG, d::InterpKDEDistribution)
    (; kde) = d
    xs = kde.kde.x                          # grid the KDE was evaluated on
    cdf = cumsum(max.(kde.kde.density, 0))  # clip any tiny negative values
    cdf ./= last(cdf)                       # normalize so the CDF ends at 1
    u = rand(rng)
    return xs[searchsortedfirst(cdf, u)]    # inverse-CDF lookup on the grid
end

# Example: fit a KDE to some data and use it as a prior in a Turing model
x = rand(Poisson(100), 100)
k = KernelDensity.kde(x)
d = InterpKDEDistribution(k)
@model function foo()
    x ~ d
end
model = foo()
chain = sample(model, NUTS(), 1_000)
```
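The `rand` method above only ever returns grid points and rebuilds the cumulative sum on every call. That is fine when it is called once to initialize sampling, but if you need many direct draws, one option is to precompute the CDF once. This is a minimal sketch, assuming the KDE's evenly spaced grid `k.x` and density values `k.density`. `GridSampler` is a hypothetical name, not part of KernelDensity.jl or Distributions.jl:

```julia
using Distributions, KernelDensity, Random, Statistics

# Hypothetical helper (not part of KernelDensity.jl): caches the normalized
# cumulative weights of a gridded KDE so each draw is a binary search.
struct GridSampler{R<:AbstractRange{<:Real}}
    x::R                  # evenly spaced grid the KDE was evaluated on
    cdf::Vector{Float64}  # normalized cumulative weights on that grid
end

function GridSampler(k::KernelDensity.UnivariateKDE)
    cdf = cumsum(max.(k.density, 0.0))  # clip any tiny negative values
    cdf ./= last(cdf)                   # normalize so the last entry is exactly 1
    return GridSampler(k.x, cdf)
end

function Random.rand(rng::Random.AbstractRNG, s::GridSampler)
    i = searchsortedfirst(s.cdf, rand(rng))  # inverse-CDF lookup of a grid cell
    # spread the draw uniformly over the cell centered on x[i], staying on the grid
    h = step(s.x)
    return clamp(s.x[i] + (rand(rng) - 0.5) * h, first(s.x), last(s.x))
end

rng = MersenneTwister(1)
data = Float64.(rand(rng, Poisson(100), 100))
s = GridSampler(kde(data))
draws = [rand(rng, s) for _ in 1:10_000]
```

The uniform spread within each cell is a piecewise-constant approximation of the density; it avoids inverting the interpolated CDF exactly while still producing continuous values.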
Note that no custom Bijector is necessary: the support of a univariate distribution can only be bounded below, bounded above, bounded on both sides, or unbounded, and bijectors are already defined for each of these cases. That is, defining `minimum` and `maximum` is sufficient here.