Here is my version of the GP fit, written with Stheno and AdvancedHMC:
using AdvancedHMC
using LinearAlgebra, Statistics, Random, Distributions
using ParameterHandling
using ParameterHandling: value, flatten
using AbstractGPs
using Stheno
using Zygote
# container for observables, covariates, params for priors
"""
    StitchProblem(x, y, ν)

Container for one fitting problem: covariates `x`, observations `y`, and `ν`,
a bag of prior-related parameters (in this script: the prior distributions for
the hyperparameters plus the fixed observation-noise level `ν.σ`).
"""
struct StitchProblem{Tobs <: AbstractVector, Tcov <: AbstractVector, Tprior}
    x::Tcov    # covariates
    y::Tobs    # observations
    ν::Tprior  # parameters for the priors
end

"""
    StitchProblem(x1, y1, x2, y2, ν)

Stitch two data sets into one problem. The inputs `x1` and `x2` are tagged as
observations of the processes `:fbb1` and `:fbb2` respectively (via `GPPPInput`)
and joined as `BlockData`; `y1` and `y2` are concatenated in the same order.
"""
StitchProblem(x1, y1, x2, y2, ν) =
    StitchProblem(BlockData(GPPPInput(:fbb1, x1), GPPPInput(:fbb2, x2)), vcat(y1, y2), ν)
"""
    build_gp(θ)

Build the two-process GPPP for hyperparameters `θ` (fields `σ`, `ρ`, `a`, `b`).
`fbb1` is a squared-exponential GP scaled by `σ` whose inputs are stretched by
`1/ρ`, so `ρ` acts as a length scale. `fbb2` is the affine transform `a * fbb1 + b`.
"""
function build_gp(θ)
    # Read the hyperparameters outside the `@gppp` block, so the macro only sees
    # the process definitions.
    σ, ρ, a, b = θ.σ, θ.ρ, θ.a, θ.b
    return @gppp let
        fbb1 = σ * stretch(GP(SEKernel()), 1 / ρ)
        fbb2 = a * fbb1 + b
    end
end
"""
    build_obs_cov(problem, θ)

Return the observation-noise covariance: a `Diagonal` with one entry per
observation in `problem.y`, each equal to `problem.ν.σ` squared.
`θ` is accepted for interface symmetry but not used; the noise level is fixed.
"""
function build_obs_cov(problem, θ)
    n_obs = length(problem.y)
    noise_var = abs2(problem.ν.σ)
    return Diagonal(fill(noise_var, n_obs))
end
"""
    pprior(problem)

Return a function `θ -> log prior density`. The log prior is the sum of
`logpdf(problem.ν.<name>_dist, θ.<name>)` over the hyperparameters `a`, `b`,
`ρ` and `σ`, which treats them as independent a priori.
"""
function pprior(problem)
    return function (θ)
        priors = problem.ν
        lp = logpdf(priors.a_dist, θ.a)
        lp += logpdf(priors.b_dist, θ.b)
        lp += logpdf(priors.ρ_dist, θ.ρ)
        lp += logpdf(priors.σ_dist, θ.σ)
        return lp
    end
end
"""
    nlml(θ, problem)

Negative log marginal likelihood of `problem.y`. The GPPP comes from
`build_gp(θ)`, evaluated at `problem.x`, with the noise covariance from
`build_obs_cov`.
"""
function nlml(θ, problem)
    gp = build_gp(θ)
    fx = gp(problem.x, build_obs_cov(problem, θ))
    return -logpdf(fx, problem.y)
end
"""
    build_model(n1=50, n2=30, σ_obs=0.02, l=1.0; rng=Random.default_rng())

Simulate two overlapping data sets and wrap them in a `StitchProblem`. Then build
the log-posterior over the flattened (unconstrained) hyperparameter vector.

Data generation:
- The first set samples `ftrue(x) = exp(-(x/l)^2)` at `n1` standard-normal inputs
  shifted to centre on `x = -1`.
- The second set samples `1.5 * ftrue(x) + 0.5` at `n2` inputs centred on `x = +1`.
- Both sets get Gaussian noise with standard deviation `σ_obs`.

# Keywords
- `rng`: random-number generator used to simulate the data. Pass a seeded RNG
  (e.g. `MersenneTwister(1)`) for reproducible data. The default draws from the
  same stream a bare `randn` call would.

# Returns
`(problem, logp, ∂logp, θ0_flat, unpack)`, where:
- `logp(θflat)` is the unnormalised log-posterior in unconstrained coordinates:
  log marginal likelihood + log prior + log-Jacobian of the positivity transform.
- `∂logp(θflat)` returns `(logp(θflat), gradient)`, computed with Zygote.
- `θ0_flat` is the flattened initial point.
- `unpack` maps a flat vector back to the named, constrained hyperparameters.
"""
function build_model(n1=50, n2=30, σ_obs=0.02, l=1.0;
                     rng::AbstractRNG=Random.default_rng())
    # Initial hyperparameters. All four are `positive`, matching the
    # positive-support priors below. To give `a` and `b` Normal priors instead,
    # make them plain floats and restrict the Jacobian term in `logp` to the
    # entries that come from `positive` parameters.
    θ0 = (σ = positive(1.0), ρ = positive(1.0), a = positive(1.0), b = positive(1.0))

    # Synthetic "actual" data.
    ftrue(x) = exp(-(x / l)^2)
    x1 = randn(rng, n1) .- 1.0
    x2 = randn(rng, n2) .+ 1.0
    err1 = σ_obs * randn(rng, n1)
    err2 = σ_obs * randn(rng, n2)
    y1 = ftrue.(x1) .+ err1
    y2 = 1.5 .* ftrue.(x2) .+ 0.5 .+ err2

    # Priors on the hyperparameters. `σ = σ_obs` is the known observation-noise
    # std used by `build_obs_cov`, not a prior.
    a_dist = Gamma(5, 0.25)
    b_dist = Gamma(5, 0.25)
    ρ_dist = InverseGamma(5, 0.25)
    σ_dist = LogNormal()
    ν = (; a_dist, b_dist, ρ_dist, σ_dist, σ = σ_obs)
    problem = StitchProblem(x1, y1, x2, y2, ν)

    θ0_flat, unflatten = flatten(θ0)
    unpack = value ∘ unflatten
    pp = pprior(problem)

    function logp(θflat)
        θ = unpack(θflat)
        # HMC moves in the unconstrained space η = θflat. ParameterHandling's
        # `positive` maps η ↦ exp(η) + ε, so dθ/dη = exp(η) and the
        # change-of-variables term is log|det J| = Σ η. Without it the sampler
        # targets roughly p(θ)/∏θ instead of the intended posterior.
        log_jacobian = sum(θflat)
        return -nlml(θ, problem) + pp(θ) + log_jacobian
    end

    function ∂logp(θflat)
        lml, back = Zygote.pullback(logp, θflat)
        ∂θflat = first(back(1.0))
        return lml, ∂θflat
    end

    return problem, logp, ∂logp, θ0_flat, unpack
end
"""
    runhmc(logp, ∂logp, θ0_flat, n_samples=1000, n_adapts=100)

Sample from `logp` with AdvancedHMC's NUTS, starting at `θ0_flat`.

- `∂logp` must return `(logp(θ), ∇logp(θ))`.
- `n_adapts` iterations are used for mass-matrix and step-size adaptation.
- Those warm-up draws are dropped from the output (`drop_warmup = true`). Confirm
  against your AdvancedHMC version whether `n_samples` counts them.

Returns `(samples, stats)` as produced by AdvancedHMC's `sample`.
"""
function runhmc(logp, ∂logp, θ0_flat, n_samples=1000, n_adapts=100)
    # Diagonal mass matrix with one entry per flattened hyperparameter.
    metric = DiagEuclideanMetric(length(θ0_flat))
    hamiltonian = Hamiltonian(metric, logp, ∂logp)

    # Heuristic initial step size, then NUTS with multinomial sampling and the
    # generalised no-U-turn termination criterion.
    ε0 = find_good_stepsize(hamiltonian, θ0_flat)
    leapfrog = Leapfrog(ε0)
    kernel = NUTS{MultinomialTS, GeneralisedNoUTurn}(leapfrog)

    # Adapt the mass matrix and step size together; 0.8 is the target acceptance rate.
    adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, leapfrog))

    draws, diagnostics = sample(hamiltonian, kernel, θ0_flat, n_samples, adaptor, n_adapts;
                                drop_warmup = true,
                                progress = true)
    return draws, diagnostics
end
# Build the synthetic problem, then draw posterior samples of the flattened
# hyperparameters with NUTS. `unpack` maps a flat vector back to the named
# hyperparameters (presumably applied to each element of `samples`; verify).
# NOTE(review): the data are simulated with random draws, so results differ
# between runs unless the RNG is seeded beforehand.
problem, logp, ∂logp, θ0_flat, unpack = build_model()
samples, stats = runhmc(logp, ∂logp, θ0_flat)
This may require a patch to Stheno that is currently in a pending PR.