# Data stitching

In addition to Gaussian processes, it’s also possible to fit the function f() using a basis set. One basis set that I’ve found works well with Bayesian analysis is compactly supported radial basis functions built on the “bump function” kernel.

They have some practical advantages over Gaussian processes, but the Gaussian process is an excellent idea and can be useful in many other contexts. I would love to see the Turing example!

A very simple thing I would do is split the X axis into a number of intervals, each containing at least one Y1 and one Y2 point, take the means of Y1 and Y2 within each interval, and then solve the OLS problem Y2 ≈ a·Y1 + b between these means.
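The interval-averaging idea can be sketched in a few lines of Julia. This is a minimal illustration, assuming y1 ≈ f(x) and y2 ≈ a·f(x) + b as in the examples in this thread; the helper `binned_ols` and the choice of interval edges are hypothetical, not code from the thread.

```julia
using Statistics

# Hypothetical helper (not from the thread): average y1 and y2 within each
# x-interval and fit mean(y2) ≈ a * mean(y1) + b by ordinary least squares.
function binned_ols(x1, y1, x2, y2, edges)
    m1 = Float64[]  # per-interval means of y1
    m2 = Float64[]  # per-interval means of y2
    nb = length(edges) - 1
    for i in 1:nb
        lo, hi = edges[i], edges[i + 1]
        # half-open intervals, except the last one, which also includes its right edge
        inbin = i == nb ? (x -> lo <= x <= hi) : (x -> lo <= x < hi)
        sel1 = inbin.(x1)
        sel2 = inbin.(x2)
        # skip intervals that do not contain points from both data sets
        if any(sel1) && any(sel2)
            push!(m1, mean(y1[sel1]))
            push!(m2, mean(y2[sel2]))
        end
    end
    X = hcat(m1, ones(length(m1)))  # design matrix with columns [mean(y1)  1]
    a, b = X \ m2                    # least-squares solution
    return a, b
end

# Usage on noise-free data where both sets share the same x grid
f(x) = exp(-x^2)
x = collect(range(-2, 2; length = 40))
y1 = f.(x)
y2 = 1.5 .* f.(x) .+ 0.5
a, b = binned_ols(x, y1, x, y2, range(-2, 2; length = 9))
```

With shared x values the bin means satisfy the linear relation exactly, so a and b are recovered. With noisy, non-overlapping samples the bin means of Y1 are themselves noisy, which tends to bias the fitted slope toward zero, and wider intervals trade that noise against blurring f within each interval.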

Ok, here we go:

```julia
using StatsPlots, Turing, LinearAlgebra, Random

Random.seed!(1)

f(x) = exp(-x^2)
x1 = randn(25) .- 1.0
x2 = randn(15) .+ 1.0
y1 = f.(x1)
y2 = 1.5*f.(x2) .+ 0.5 ## true values of a = 1.5 and b = 0.5

p1 = plot([x1, x2], [y1, y2]; seriestype=:scatter) # don't line up

display(p1)

using AbstractGPs

# Kernel with lengthscale ρ and scale σ
kernel(ρ, σ) = σ * Matern32Kernel() ∘ ScaleTransform(1 / ρ)

@model function stitchgp(xbase, ybase, xtform, ytform)
    a ~ Gamma(5, 1.0/4) # a is of order 1
    b ~ Gamma(5, 1.0/4) # b is of order 1

    # The Gaussian process
    ρ ~ InverseGamma(5, 1.0/4)  # Why inverse gamma: https://mc-stan.org/docs/stan-users-guide/fit-gp.html
    σ ~ LogNormal()
    f = GP(kernel(ρ, σ))
    # Note: GP needs to see all data to smooth between xbase and xtform
    xs = vcat(xbase, xtform)
    fx ~ f(xs, 1e-6)  # Small jitter for numerical stability

    fx1 = fx[1:length(xbase)]
    fx2 = fx[length(xbase)+1:end]

    # The observation model
    σ_obs = 0.05  # Fixed as in original code
    ybase ~ arraydist(Normal.(fx1, σ_obs))
    ytform ~ arraydist(Normal.(a.*fx2 .+ b, σ_obs))
end

mgp = stitchgp(x1, y1, x2, y2)

sgp = sample(mgp, NUTS(300, 0.8), 100)

plot(sgp[:, [:a, :b, :ρ], 1])
```

Note that I have reduced the number of data points as well as samples, as it is quite slow. GPs are always a bit tricky to apply to many data points without some clever approximations, but this toy example should not be that slow … I have not done a direct comparison with Stan, but we had successfully applied it to a couple thousand data points.

I rewrote your GP version using Stheno with Zygote AD, and with AdvancedHMC directly rather than through the Turing wrapper (just because I’ve been exploring those packages recently). Using the default settings shown in the documented examples (e.g. here, with extra code for the priors), I get 1000 samples for 50+30 data points in 36 seconds. This is far faster than what I get with your code; I’d say high-level Turing is currently a questionable tool for this.

Although the linear algebra keeps the cores on my system somewhat busy for the Stheno code (unlike Turing+AbstractGPs), it only takes about twice as long for twice as many points (77 seconds, 1000 samples, 100+60 data points). That is nowhere near the N^3 asymptotic scaling of the Cholesky factorization, so the run time is apparently dominated by other overhead, and there’s still room for improvement by reducing allocations etc.
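As a rough check on that reasoning: if each gradient evaluation were dominated by the O(N^3) Cholesky factorization, going from N = 50+30 = 80 to N = 100+60 = 160 points would multiply the run time by about eight, whereas the observed ratio is close to two. This ignores any change in the number of gradient evaluations per NUTS sample, so it is only an order-of-magnitude comparison.

$$
\left(\frac{160}{80}\right)^{3} = 8
\qquad \text{vs.} \qquad
\frac{77\,\mathrm{s}}{36\,\mathrm{s}} \approx 2.1
$$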

Nice. It’s a pity that GPs are slow in Turing, as the code quite closely follows the generative model. I had tried some small variants, e.g., transforming standard normal variates with the Cholesky factor of the kernel matrix, but with little success: the straightforward code was even a bit faster.
I will need to look into AbstractGPs and Stheno some more; thanks for the pointer to the relevant docs. In any case, would you mind sharing your code here?
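The "transform standard normal variates with the Cholesky factor" variant is the usual non-centered parameterization. Below is a minimal plain-Julia sketch of the transform itself, assuming the same Matérn-3/2 kernel as the Turing model; the kernel is written out by hand, and the names `matern32` and `noncentered_gp` are hypothetical, so this shows the idea rather than the exact variant that was tried.

```julia
using LinearAlgebra, Random

# Matérn-3/2 kernel with variance σ and lengthscale ρ, written out by hand; this
# should correspond to σ * Matern32Kernel() ∘ ScaleTransform(1 / ρ) in the Turing example
matern32(r, ρ, σ) = σ * (1 + sqrt(3) * r / ρ) * exp(-sqrt(3) * r / ρ)

# Non-centered draw: f = L * z with z ~ N(0, I) and L the Cholesky factor of K,
# so f ~ N(0, K + jitter * I). In a Turing model z would be the sampled variable,
# e.g. z ~ MvNormal(zeros(n), I), and f a deterministic function of z, ρ and σ.
function noncentered_gp(xs, ρ, σ, z; jitter = 1e-6)
    K = [matern32(abs(xi - xj), ρ, σ) for xi in xs, xj in xs]
    L = cholesky(Symmetric(K + jitter * I)).L
    return L * z, K, L
end

Random.seed!(1)
xs = vcat(randn(25) .- 1.0, randn(15) .+ 1.0)  # same layout as x1, x2 above
z = randn(length(xs))
fvals, K, L = noncentered_gp(xs, 1.0, 1.0, z)
```

The sampler then explores z, which is a priori independent of ρ and σ; this often helps when the data are weakly informative, though, as observed here, it does not necessarily make each gradient evaluation cheaper.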

Is this a general issue, or was it the choice of AD backend? Have you tried Zygote or ReverseDiff?

Good point; I had not used Turing in a while and was not aware that ForwardDiff is the default AD backend. Unfortunately, ReverseDiff seems even slower on the example code, and Zygote never finishes compiling (I killed it after 15 minutes) … I don’t really know what’s going on.

Enable the ReverseDiff compiled tape cache.
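A minimal configuration sketch. The commented lines are, to my knowledge, the global switches in the Turing versions of this thread's era (`Turing.setrdcache(true)` is the call mentioned in a later reply); the last line assumes the newer per-sampler `adtype` keyword with an ADTypes `AutoReverseDiff` object, so check it against the Turing version you have installed.

```julia
using Turing
using ReverseDiff  # the ReverseDiff backend must be loaded explicitly

# Older Turing releases: global AD settings
# Turing.setadbackend(:reversediff)
# Turing.setrdcache(true)   # reuse the compiled ReverseDiff tape between gradient calls

# Newer Turing releases (assumption): the AD backend is chosen per sampler;
# compile = true records and compiles the tape once and reuses it
sampler = NUTS(300, 0.8; adtype = AutoReverseDiff(; compile = true))
```

A compiled tape replays the control flow it recorded, so it is only correct when branches in the model do not depend on parameter values.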

Here is my Stheno/AdvancedHMC version of the GP fit:

```julia
using LinearAlgebra, Statistics, Random, Distributions
using ParameterHandling
using ParameterHandling: value, flatten
using AbstractGPs
using Stheno
using Zygote
using AdvancedHMC

# container for observables, covariates, params for priors
struct StitchProblem{TY <: AbstractVector, TX <: AbstractVector, Tν}
    x::TX # covariates
    y::TY # observations
    ν::Tν
end
function StitchProblem(x1, y1, x2, y2, ν)
    # tag each data set with the GPPP component it observes
    x = BlockData(GPPPInput(:fbb1, x1), GPPPInput(:fbb2, x2))
    y = vcat(y1, y2)
    StitchProblem(x, y, ν)
end

# fbb1 is the latent curve; fbb2 is its affine transform a*fbb1 + b
function build_gp(θ)
    return @gppp let
        fbb1 = θ.σ * stretch(GP(SEKernel()), 1 / θ.ρ)
        fbb2 = θ.a * fbb1 + θ.b
    end
end

# fixed, homoscedastic observation noise
function build_obs_cov(problem, θ)
    v = problem.ν.σ^2
    ny = length(problem.y)
    return Diagonal(fill(v, ny))
end

# log prior density of the parameters
function pprior(problem)
    function pp(θ)
        ν = problem.ν
        l_prior = (logpdf(ν.a_dist, θ.a)
                   + logpdf(ν.b_dist, θ.b)
                   + logpdf(ν.ρ_dist, θ.ρ)
                   + logpdf(ν.σ_dist, θ.σ)
                   )
        return l_prior
    end
    return pp
end

# negative log marginal likelihood
function nlml(θ, problem)
    f = build_gp(θ)
    C = build_obs_cov(problem, θ)
    return -logpdf(f(problem.x, C), problem.y)
end

function build_model(n1=50, n2=30, σ_obs=0.02, l=1.0)

    # θ0 = (σ = positive(1.0), ρ = positive(1.0), a = 1.0, b = 1.0)
    θ0 = (σ = positive(1.0), ρ = positive(1.0), a = positive(1.0), b = positive(1.0))

    # "actual" data
    ftrue(x) = exp(-(x/l)^2)
    x1 = randn(n1) .- 1.0
    x2 = randn(n2) .+ 1.0
    err1 = σ_obs * randn(n1)
    err2 = σ_obs * randn(n2)
    y1 = ftrue.(x1) + err1
    y2 = 1.5*ftrue.(x2) .+ 0.5 + err2

    # a_dist = Normal(0., 5.)
    # b_dist = Normal(0., 5.)
    a_dist = Gamma(5, 0.25)
    b_dist = Gamma(5, 0.25)
    ρ_dist = InverseGamma(5, 0.25)
    σ_dist = LogNormal()
    ν = (;a_dist, b_dist, ρ_dist, σ_dist, σ = σ_obs)

    problem = StitchProblem(x1, y1, x2, y2, ν)
    θ0_flat, unflatten = flatten(θ0)
    unpack = value ∘ unflatten  # flat vector -> constrained named tuple
    pp = pprior(problem)

    # log posterior (up to a constant) on the flat parameter vector
    function logp(θflat)
        θ = unpack(θflat)
        return -nlml(θ, problem) + pp(θ)
    end
    # value and gradient via Zygote, as AdvancedHMC's Hamiltonian expects
    function ∂logp(θflat)
        lml, back = Zygote.pullback(logp, θflat)
        ∂θflat = first(back(1.0))
        return lml, ∂θflat
    end
    # break here so we can diagnose code
    return problem, logp, ∂logp, θ0_flat, unpack
end

function runhmc(logp, ∂logp, θ0_flat, n_samples=1000, n_adapts=100)
    D = length(θ0_flat)
    metric = DiagEuclideanMetric(D)
    h = Hamiltonian(metric, logp, ∂logp)
    initial_eps = find_good_stepsize(h, θ0_flat)
    integrator = Leapfrog(initial_eps)
    prop = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
    # adaptor and sampling call assumed from the AdvancedHMC README example:
    # Stan-style adaptation of the diagonal mass matrix and step size
    adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
    samples, stats = sample(h, prop, θ0_flat, n_samples, adaptor, n_adapts;
                            drop_warmup = true,
                            progress = true)
    return samples, stats
end

problem, logp, ∂logp, θ0_flat, unpack = build_model()
samples, stats = runhmc(logp, ∂logp, θ0_flat)
```

It may require a patch to Stheno from a pending PR.

Ok, it seems a bit faster with Turing.setrdcache(true), but still slower than ForwardDiff. Also, the Tracker backend fails with a stack overflow in promote_type. There seems to be some problem with Turing here … Ideally, GP code should be dominated by the Cholesky decomposition of the kernel matrix.
@Ralph_Smith Thanks for the code, I will give it a try. Unfortunately, it’s quite a bit longer than the Turing version.
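To make the "dominated by the Cholesky decomposition" point concrete, here is a minimal sketch of a GP log marginal likelihood computed through a single Cholesky factorization. It is written by hand with LinearAlgebra only; the function name `gp_logml` is hypothetical, and the AbstractGPs/Stheno `logpdf` calls in this thread presumably do something equivalent internally.

```julia
using LinearAlgebra

# Log density of y under y ~ N(0, K + σ_obs² I), evaluated through one Cholesky
# factorization, the O(n³) step that should dominate a GP likelihood evaluation.
function gp_logml(K::AbstractMatrix, y::AbstractVector, σ_obs::Real)
    n = length(y)
    C = cholesky(Symmetric(K + σ_obs^2 * I))  # K + σ² I = L * L'
    α = C.L \ y                                # whitened residuals, α = L⁻¹ y
    # -½ yᵀ(K + σ²I)⁻¹y - ½ log det(K + σ²I) - (n/2) log 2π
    return -dot(α, α) / 2 - logdet(C) / 2 - n * log(2π) / 2
end

K = [2.0 0.5; 0.5 1.0]
y = [0.3, -0.2]
lml = gp_logml(K, y, 0.1)
```

Everything apart from the factorization is O(n²), so if the measured run time grows much more slowly than n³, the time is going into something other than this linear algebra.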

Ok, all. Here’s actually a much simpler set of code which is also general purpose…

The bumpfun is a so-called “compactly supported radial basis function”: it is infinitely smooth and nonzero only on (-1, 1). There are theorems stating that radial basis functions with evenly spaced centers become dense in the space of continuous functions on a compact interval as the number of centers increases. So basically, if you put in enough centers, you can approximate any continuous function well enough.
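Written out, the bump function implemented by bumpfun below and the basis expansion used by rbf are as follows, where the weights w_j, centers c_j and common width s correspond to coefs, centers and scale in the code. Note that φ(0) = 1 and φ vanishes identically for |x| ≥ 1.

$$
\phi(x) =
\begin{cases}
\exp\!\left(1 - \dfrac{1}{1 - x^{2}}\right), & |x| < 1,\\[4pt]
0, & |x| \ge 1,
\end{cases}
\qquad
f(x) \approx \sum_{j=1}^{m} w_j\, \phi\!\left(\frac{x - c_j}{s}\right)
$$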

The advantage of compactly supported radial basis functions is that they reduce the correlation between coefficients (two basis functions whose supports do not overlap are completely uncorrelated) and make sampling easier than with global bases such as Fourier or non-compact radial basis functions.
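The locality argument can be seen in the design matrix. This sketch reuses the same bumpfun and the centers/scale from the model call below (centers -5:5, scale 2.5); the evaluation grid xs is an arbitrary assumption for illustration.

```julia
# Same bump function as in the model below: 1 at x = 0, exactly 0 outside (-1, 1)
bumpfun(x) = x > -one(x) && x < one(x) ? exp(one(x) - one(x)/(one(x) - x^2)) : zero(x)

centers = collect(-5.0:5.0)   # as in stitchdata2(..., collect(-5:5), 2.5)
s = 2.5
xs = collect(range(-4, 4; length = 81))

# Design matrix: B[i, j] = value of basis function j at xs[i]
B = [bumpfun((x - c) / s) for x in xs, c in centers]

# Each x only "sees" centers closer than s, i.e. at most 5 of the 11 here
nnz_per_row = [count(!iszero, B[i, :]) for i in axes(B, 1)]

# Gram matrix B'B: basis functions whose centers are at least 2s apart never
# overlap, so those entries are exactly zero and the matrix is banded
G = B' * B
```

With a global basis every entry of B would generally be nonzero, so every coefficient would be coupled to every other one through the likelihood.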

```julia
# reuses x1, y1, x2, y2 and the packages loaded in the first example
# smooth bump: equals 1 at x = 0 and is exactly zero outside (-1, 1)
bumpfun(x) = x > -one(x) && x < one(x) ? exp(one(x) - one(x)/(one(x) - x^2)) : zero(x)

# basis expansion: weighted sum of bumps of width `scale` at the given centers
rbf(coefs, centers, scale, x) = sum(coef * bumpfun((x - c)/scale) for (coef, c) in zip(coefs, centers))

@model function stitchdata2(xbase, ybase, xtform, ytform, centers, s)
    err ~ Gamma(3.0, 0.1/2.0)
    a ~ Normal(0.0, 10.0)
    b ~ Normal(0.0, 10.0)
    coefs ~ MvNormal(repeat([0.0], length(centers)), 50.0^2*I(length(centers)))
    y1pred = rbf.(Ref(coefs), Ref(centers), s, xbase)
    y2pred = rbf.(Ref(coefs), Ref(centers), s, xtform)
    ybase ~ MvNormal(y1pred, err^2*I(length(ybase)))
    ytform ~ MvNormal(a.*y2pred .+ b, err^2*I(length(ytform)))
end

mod2 = stitchdata2(x1, y1, x2, y2, collect(-5:5), 2.5)

s2 = sample(mod2, NUTS(400, 0.8), 200)

plot(s2[:, [:a, :b], 1])
(meana, meanb) = (mean(s2[:, :a, 1]), mean(s2[:, :b, 1]))

# map the transformed data set back onto the base scale
scatter(x1, y1)
scatter!(x2, (y2 .- meanb) ./ meana)
```

Sampling took ~ 40 seconds on my computer using the ForwardDiff default.