How to Use a Distribution Estimated via Kernel Density Estimation (KDE) as a Prior in Turing.jl

The simplest solution might be something like the approach proposed in Sampling from a KDE object - #4 by sethaxen. A KDE is just a mixture model, so if you keep a copy of the data and the bandwidth, you can convert it to a Distributions.MixtureModel.
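A minimal sketch of that conversion follows. It assumes a Gaussian kernel (KernelDensity's default) and a bandwidth fixed up front, so the same `h` is passed to both `kde` and the mixture. The names `data`, `h`, and `mix` and the value `h = 2.0` are illustrative choices, not taken from the linked post. Because `kde` uses a gridded FFT approximation, the two densities agree only approximately.

```julia
using Distributions, KernelDensity, Random

rng = Random.MersenneTwister(1)
data = float.(rand(rng, Poisson(100), 100))  # illustrative data, converted to Float64

# Fix the bandwidth explicitly so the KDE and the mixture use the same value
# (assumption: 2.0 is an arbitrary choice for illustration).
h = 2.0
k = kde(data; bandwidth = h)  # gridded, FFT-based Gaussian KDE

# The exact Gaussian KDE: an equally weighted mixture of Normal(xᵢ, h) components
mix = MixtureModel([Normal(xi, h) for xi in data])

# Every logpdf(mix, x) call sums over all length(data) components,
# so its cost grows linearly with the size of the data.
lp = logpdf(mix, 100.0)
```

Used as a prior, this is simply `x ~ mix` inside a `@model`.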

This may not be ideal for you, since the cost of each logpdf evaluation scales with the size of the data the KDE was fit to. An alternative is to interpolate between the grid points at which the KDE was evaluated. That might look something like this:

using KernelDensity, Distributions, Interpolations, Random, Turing

struct InterpKDEDistribution{T<:Real,K<:KernelDensity.InterpKDE} <: ContinuousUnivariateDistribution
    kde::K
end
function InterpKDEDistribution(k::KernelDensity.InterpKDE)
    T = eltype(k.kde.x)
    return InterpKDEDistribution{T,typeof(k)}(k)
end
function InterpKDEDistribution(k::KernelDensity.UnivariateKDE)
    return InterpKDEDistribution(KernelDensity.InterpKDE(k))
end

function Distributions.minimum(d::InterpKDEDistribution)
    return first(only(Interpolations.bounds(d.kde.itp.itp)))
end

function Distributions.maximum(d::InterpKDEDistribution)
    return last(only(Interpolations.bounds(d.kde.itp.itp)))
end

function Distributions.pdf(d::InterpKDEDistribution, x::Real)
    return pdf(d.kde, x)
end

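function Distributions.logpdf(d::InterpKDEDistribution, x::Real)
    p = pdf(d, x)
    # the gridded/interpolated density can dip slightly below zero in the tails,
    # so clamp at zero before taking the log (giving -Inf there instead of an error)
    return log(max(p, zero(p)))
end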

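# A rough implementation of rand is required, but it is only called once
# when sampling starts, so a more efficient version is not needed here.
# This does inverse-CDF sampling on the grid the KDE was evaluated on.
function Random.rand(rng::Random.AbstractRNG, d::InterpKDEDistribution)
    (; kde) = d
    xs = kde.kde.x                          # grid points of the underlying UnivariateKDE
    cdf = cumsum(max.(kde.kde.density, 0))  # clip tiny negative values, then accumulate
    cdf ./= last(cdf)                       # normalize so the CDF ends at exactly 1
    u = rand(rng)
    return xs[searchsortedfirst(cdf, u)]    # first grid point whose CDF reaches u
end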

data = rand(Poisson(100), 100)
k = KernelDensity.kde(data)
d = InterpKDEDistribution(k)

# pass the prior in as an argument rather than capturing a non-constant global
@model function foo(prior)
    x ~ prior
end

model = foo(d)
chain = sample(model, NUTS(), 1_000)

Note that no custom Bijector is necessary. The support of a continuous univariate distribution can only be bounded below, bounded above, bounded on both sides, or unbounded, and bijectors are already defined for all of these cases. Defining minimum and maximum is therefore sufficient here.
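To illustrate why the bounds alone are enough: a distribution supported on a finite interval (a, b) can be mapped onto the whole real line with a scaled logit, which is similar in spirit to what Bijectors.jl does for doubly bounded supports. The sketch below is a hand-written illustration with hypothetical helper names (`to_unconstrained`, `to_constrained`, `logabsdetjac_inv`); it is not Bijectors' actual implementation or API.

```julia
# Hypothetical helpers illustrating a scaled-logit bijection between (a, b) and ℝ.
logit(p) = log(p / (1 - p))
logistic(y) = 1 / (1 + exp(-y))

# constrained x in (a, b)  ->  unconstrained y in ℝ
to_unconstrained(x, a, b) = logit((x - a) / (b - a))

# unconstrained y in ℝ  ->  constrained x in (a, b)
to_constrained(y, a, b) = a + (b - a) * logistic(y)

# log|dx/dy| of the inverse map; a sampler working in y adds this
# to logpdf(d, x) so that the density on ℝ is correct.
function logabsdetjac_inv(y, a, b)
    s = logistic(y)
    return log(b - a) + log(s) + log(1 - s)
end

a, b = 60.0, 140.0  # illustrative bounds, standing in for minimum(d) and maximum(d)
x = 97.5
y = to_unconstrained(x, a, b)
x_back = to_constrained(y, a, b)
```

Only a and b enter the transform, which is why supplying minimum and maximum lets a generic bijector handle the distribution.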
