Data stitching

Ok, here’s my example. The Turing sampling took 2-3 seconds.


```julia
using StatsPlots, Turing, LinearAlgebra, Random

Random.seed!(1)

f(x) = exp(-x^2)
x1 = randn(50) .- 1.0
x2 = randn(30) .+ 1.0
y1 = f.(x1)
y2 = 1.5*f.(x2) .+ 0.5 ## true values of a = 1.5 and b = 0.5

p1 = plot([x1, x2], [y1, y2]; seriestype=:scatter) # the two data sets don't line up

display(p1)

## how to determine a,b if we don't know f()?

@model function stitchdata(xbase,ybase,xtform,ytform)
    a ~ Gamma(5,1.0/4) # a is of order 1
    b ~ Gamma(5,1.0/4) # b is of order 1
    # slopes (m) and intercepts (c) of a local linear approximation to f in each window
    m ~ MvNormal(repeat([0.0],4),100.0^2*I(4))
    c ~ MvNormal(repeat([0.0],4),100.0^2*I(4))
    xbneighs = Vector{eltype(xbase)}[]
    ybneighs = Vector{eltype(ybase)}[]
    xtneighs = Vector{eltype(xtform)}[]
    ytneighs = Vector{eltype(ytform)}[]
    centers = (-1.0,-0.5,0.5,1.0)
    for (i,dx) in enumerate(centers)
        # points within 0.25 of this window's center
        inwindow = x -> x > dx - 0.25 && x < dx + 0.25
        xneighindex = findall(inwindow, xbase)
        push!(xbneighs, xbase[xneighindex])
        push!(ybneighs, ybase[xneighindex])
        tneighindex = findall(inwindow, xtform)
        push!(xtneighs, xtform[tneighindex])
        push!(ytneighs, ytform[tneighindex])
    end
    # base data: y ≈ m[i]*(x - centers[i]) + c[i], noise sd 0.05
    for (i,(x,y)) in enumerate(zip(xbneighs,ybneighs))
        Turing.@addlogprob!(logpdf(MvNormal(m[i].*(x .- centers[i]) .+ c[i], 0.05^2*I(length(y))), y))
    end
    # transformed data: the same local lines, mapped through a*(...) + b
    for (i,(x,y)) in enumerate(zip(xtneighs,ytneighs))
        Turing.@addlogprob!(logpdf(MvNormal(a.*(m[i].*(x .- centers[i]) .+ c[i]) .+ b, 0.05^2*I(length(y))), y))
    end
end

mm = stitchdata(x1,y1,x2,y2)

s = sample(mm,NUTS(400,0.8),200)

plot(s[:,[:a,:b],1])
```

This produces the data:

And the posterior inference for a and b is:

I chose the “neighborhoods” ad hoc: windows of half-width 0.25 centered at x = -1, -0.5, 0.5 and 1.
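Because the windows are an ad hoc choice, it can help to try different ones quickly before running MCMC. Below is a minimal sketch, assuming a plain two-stage least-squares fit is good enough as a cross-check (this is a swapped-in, non-Bayesian technique, not the Turing model above): fit a line to the base data in each window, then regress the transformed data on those lines' predictions. `stitch_ls` and its `centers`/`halfwidth` keywords are hypothetical names; windows with fewer than two base points or no transformed points are skipped.

```julia
using LinearAlgebra

# Hypothetical two-stage least-squares version of the stitching idea:
# 1) in each window, fit y ≈ mk*(x - x0) + ck to the base data only;
# 2) regress the transformed y on those lines' predictions: y ≈ a*pred + b.
function stitch_ls(xbase, ybase, xt, yt;
                   centers = (-1.0, -0.5, 0.5, 1.0), halfwidth = 0.25)
    preds = Float64[]  # base-line predictions at the transformed x values
    obs = Float64[]    # the matching transformed y values
    for x0 in centers
        inwindow = x -> x0 - halfwidth < x < x0 + halfwidth
        ib = findall(inwindow, xbase)
        it = findall(inwindow, xt)
        # a line needs at least two base points; skip windows without overlap
        (length(ib) < 2 || isempty(it)) && continue
        mk, ck = hcat(xbase[ib] .- x0, ones(length(ib))) \ ybase[ib]
        append!(preds, mk .* (xt[it] .- x0) .+ ck)
        append!(obs, yt[it])
    end
    length(obs) < 2 && error("too few overlapping points to estimate a and b")
    a, b = hcat(preds, ones(length(preds))) \ obs
    return a, b
end
```

For example, `stitch_ls(x1, y1, x2, y2; halfwidth = 0.4)` returns point estimates of `(a, b)` to compare against the posterior. Since f is curved, the local lines are only approximations, so these estimates can be biased by an amount that depends on the window width. Unlike the Turing model, the lines here are fitted to the base data alone, so the transformed data do not inform the slopes and intercepts.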

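I had to use the `@addlogprob!` macro because the observations in each window are subsets of the data computed inside the model. Turing treats the left-hand side of a `~` as observed data only when it is a model argument (or indexes into one), so writing `y ~ ...` on those locally computed subsets would make Turing treat them as parameters to sample.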
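A sketch of an alternative that avoids `@addlogprob!`, assuming the windows can be computed once outside the model: pass the per-window subsets in as model arguments, so that `yb[k] ~ ...` is treated as an observation. This also avoids redoing the `findall` calls at every log-density evaluation. `splitwindows` and `stitchdata_args` are hypothetical names, and empty windows are skipped rather than given a zero-length `MvNormal`.

```julia
using Turing, LinearAlgebra

# Hypothetical helper: precompute the points in each window centers[k] ± halfwidth
# (open interval, as in the original model) once, outside the model.
function splitwindows(xs, ys, centers; halfwidth = 0.25)
    idx = [findall(x -> x0 - halfwidth < x < x0 + halfwidth, xs) for x0 in centers]
    return [xs[i] for i in idx], [ys[i] for i in idx]
end

# yb and yt are model arguments, so `yb[k] ~ ...` and `yt[k] ~ ...` are observations.
@model function stitchdata_args(xb, yb, xt, yt, centers)
    K = length(centers)
    a ~ Gamma(5, 1.0 / 4)
    b ~ Gamma(5, 1.0 / 4)
    m ~ MvNormal(zeros(K), 100.0^2 * I(K))  # local slopes
    c ~ MvNormal(zeros(K), 100.0^2 * I(K))  # local intercepts
    for k in 1:K
        if !isempty(yb[k])  # base data follow the local line directly
            μb = m[k] .* (xb[k] .- centers[k]) .+ c[k]
            yb[k] ~ MvNormal(μb, 0.05^2 * I(length(μb)))
        end
        if !isempty(yt[k])  # transformed data follow a * (same line) + b
            μt = a .* (m[k] .* (xt[k] .- centers[k]) .+ c[k]) .+ b
            yt[k] ~ MvNormal(μt, 0.05^2 * I(length(μt)))
        end
    end
end
```

With the post's data this would be `centers = (-1.0, -0.5, 0.5, 1.0)`, `xb, yb = splitwindows(x1, y1, centers)`, `xt, yt = splitwindows(x2, y2, centers)`, then `sample(stitchdata_args(xb, yb, xt, yt, centers), NUTS(400, 0.8), 200)`.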

It might also make more sense to have some measurement noise in the data and/or to fit the scale of the measurement noise instead of fixing it at 0.05.
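Fitting the noise scale needs only a small change to the model in the post: turn the fixed 0.05 into a parameter σ. A minimal sketch; the `Exponential(0.1)` prior on σ and the name `stitchdata_sigma` are assumptions, and empty windows are skipped.

```julia
using Turing, LinearAlgebra

# Sketch: the post's stitching likelihood, with the noise sd σ as a parameter.
# The Exponential(0.1) prior on σ is an assumption, not from the post.
@model function stitchdata_sigma(xbase, ybase, xtform, ytform)
    centers = (-1.0, -0.5, 0.5, 1.0)
    K = length(centers)
    a ~ Gamma(5, 1.0 / 4)
    b ~ Gamma(5, 1.0 / 4)
    σ ~ Exponential(0.1)
    m ~ MvNormal(zeros(K), 100.0^2 * I(K))
    c ~ MvNormal(zeros(K), 100.0^2 * I(K))
    for (k, x0) in enumerate(centers)
        inwindow = x -> x0 - 0.25 < x < x0 + 0.25
        ib = findall(inwindow, xbase)
        it = findall(inwindow, xtform)
        if !isempty(ib)  # base data: local line m[k]*(x - x0) + c[k]
            μb = m[k] .* (xbase[ib] .- x0) .+ c[k]
            Turing.@addlogprob!(logpdf(MvNormal(μb, σ^2 * I(length(ib))), ybase[ib]))
        end
        if !isempty(it)  # transformed data: a * (same line) + b
            μt = a .* (m[k] .* (xtform[it] .- x0) .+ c[k]) .+ b
            Turing.@addlogprob!(logpdf(MvNormal(μt, σ^2 * I(length(it))), ytform[it]))
        end
    end
end
```

It is called like the original, e.g. `sample(stitchdata_sigma(x1, y1, x2, y2), NUTS(400, 0.8), 200)`, and the chain's `:σ` column then holds the posterior draws of the noise scale.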

When I added Normal(0.0, 0.05) noise to the data, I got:

and
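One way to add that noise reproducibly, assuming it is i.i.d. Normal(0, 0.05) applied to both `y1` and `y2` after they are computed, is to draw it from an explicit RNG. `addnoise` is a hypothetical helper, not necessarily what was used here.

```julia
using Random

# Hypothetical helper: return a copy of `y` with i.i.d. Normal(0, σ) noise added,
# drawn from `rng` so the noisy data set is reproducible.
addnoise(rng::AbstractRNG, y::AbstractVector{<:Real}, σ::Real) =
    y .+ σ .* randn(rng, length(y))
```

For example, `rng = MersenneTwister(1)`, then `y1 = addnoise(rng, y1, 0.05)` and `y2 = addnoise(rng, y2, 0.05)` before building the model.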
