Ok, here’s my example. The Turing sampling took 2-3 seconds.
using StatsPlots, Turing, LinearAlgebra, Random

## Seed the task-local global RNG so the synthetic data and the sampling run
## are reproducible. `Random.seed!` is the documented public API
## (`Random.set_global_seed!` is an internal function and may not exist on
## every Julia version).
Random.seed!(1)

## The "unknown" underlying curve both data sets are measurements of.
f(x) = exp(-x^2)

## Two overlapping windows of inputs: a base set centered at -1 and a
## to-be-transformed set centered at +1.
x1 = randn(50) .- 1.0
x2 = randn(30) .+ 1.0

## The second instrument reads a*f(x) + b with true a = 1.5, b = 0.5;
## the goal below is to recover a and b without knowing f.
y1 = f.(x1)
y2 = 1.5 * f.(x2) .+ 0.5 ## true values of a = 1.5 and b = 0.5

## Visual check: the two point clouds do not line up before calibration.
p1 = plot([x1, x2], [y1, y2]; seriestype=:scatter)
display(p1)
## how to determine a,b if we don't know f()?
## Turing model that "stitches" two data sets measuring the same unknown
## curve: (xbase, ybase) is the reference measurement and (xtform, ytform)
## is assumed to be an affine transform a*y + b of it. The unknown curve is
## approximated locally by straight lines in four hand-picked neighborhoods.
## Infers: a (scale), b (offset), and per-neighborhood line slopes m and
## intercepts c.
@model function stitchdata(xbase,ybase,xtform,ytform)
# Weakly informative priors keeping a and b of order 1 (Gamma(5, 1/4) has mean 1.25).
a ~ Gamma(5,1.0/4) # a is of order 1
b ~ Gamma(5,1.0/4) # b is of order 1
# Broad priors on the four local line slopes and intercepts (one per neighborhood).
m ~ MvNormal(repeat([0.0],4),100.0^2*I(4))
c ~ MvNormal(repeat([0.0],4),100.0^2*I(4))
# Collect the subset of each data set that falls inside each neighborhood.
xbneighs = Vector{eltype(xbase)}[]
ybneighs = Vector{eltype(ybase)}[]
xtneighs = Vector{eltype(xbase)}[]
ytneighs = Vector{eltype(ybase)}[]
# Ad-hoc neighborhood centers; each neighborhood is center +/- 0.25.
centers = (-1.0,-0.5,.5,1.0)
for (i,dx) in enumerate(centers)
# Predicate selecting points within 0.25 of this center (shadows the global f; unrelated to it).
f = x -> x > dx - .25 && x < dx + 0.25
xneighindex = findall(f, xbase)
#@show xneighindex
push!(xbneighs, xbase[xneighindex])
push!(ybneighs, ybase[xneighindex])
tneighindex = findall(f,xtform)
push!(xtneighs, xtform[tneighindex])
push!(ytneighs, ytform[tneighindex])
end
# Likelihood for the base data: in neighborhood i, y ≈ m[i]*(x - center) + c[i]
# with small fixed observation noise (sd 0.05). @addlogprob! is used because the
# observations are data subsets built inside the model, which Turing's ~ syntax
# does not handle directly.
for (i,(x,y)) in enumerate(zip(xbneighs,ybneighs))
Turing.@addlogprob!(logpdf(MvNormal(m[i].*(x .- centers[i]) .+ c[i], 0.05^2*I(length(y))), y))
end
# Likelihood for the transformed data: same local line, mapped through y -> a*y + b.
for (i,(x,y)) in enumerate(zip(xtneighs,ytneighs))
Turing.@addlogprob!(logpdf(MvNormal(a.*(m[i].*(x .- centers[i]) .+ c[i]) .+ b, 0.05^2*I(length(y))), y))
end
end
## Condition the model on the base data (x1, y1) and the data to be
## calibrated (x2, y2), then sample the posterior.
mm = stitchdata(x1, y1, x2, y2)

## NUTS with 400 adaptation steps and a 0.8 target acceptance rate;
## draw 200 posterior samples.
sampler = NUTS(400, 0.8)
s = sample(mm, sampler, 200)

## Trace/density plots of the calibration parameters a and b from chain 1.
plot(s[:, [:a, :b], 1])
This produces the data:
And the inferred posteriors for a and b are:
I just ad-hoc decided on the “neighborhoods” to use.
I had to use the @addlogprob!
macro because I'm subsetting the data inside the model, and Turing gets confused by that if I don't.
Also, it might make more sense to actually have some measurement noise and/or fit the scale of the measurement noise…
when I added Normal(0.0,0.05) noise to the data, I get:
and