Hi everybody,
I've been playing around with the `predict()` function in Turing and ran into some behaviour I don't quite understand. When I use it to generate new output from a chain, it seems to take only some of the chain's variables into account and ignore the others.
Here is some example code that reproduces the problem. In it, I simulate data for a simple linear regression and fit that data with Turing. I then use the `predict` function to generate new data from the fitted chain, and fit the predicted data again with the same model.
```julia
using Turing, StatsPlots

@model function linear_reg(x, y)
    β ~ Normal(0, 1)
    σ ~ Exponential(0.1)
    for i ∈ eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end;

pars = Array{Float64}(undef, 20, 2, 2)
for i = 1:20
    # simulate some data and fit a model to it
    f(x) = (i-20)/10 * x + i/100 * randn();
    Δ = 0.1; xs_train = 0:Δ:100; ys_train = f.(xs_train);
    m_train = linear_reg(xs_train, ys_train);
    chain_lin_reg = sample(m_train, NUTS(), 1000);

    # make predictions based on the fitted model
    m_test = linear_reg(xs_train, Vector{Union{Missing, Float64}}(undef, length(ys_train)));
    predictions = predict(m_test, chain_lin_reg)
    ys_predictions = vec(mean(Array(group(predictions, :y)); dims = 1))

    # fit a model to the predictions
    m_recover = linear_reg(xs_train, ys_predictions);
    chain_lin_reg_recover = sample(m_recover, NUTS(), 1000);

    # store the parameters
    pars[i, 1, :] = mean(Array(chain_lin_reg), dims = 1)
    pars[i, 2, :] = mean(Array(chain_lin_reg_recover), dims = 1)
end

p1 = scatter(pars[:,1,1], pars[:,2,1], legend=false, aspect_ratio=:equal);
title!("β")
p2 = scatter(pars[:,1,2], pars[:,2,2], legend=false, aspect_ratio=:equal);
title!("σ")
plot(p1, p2, layout = (1,2), legend=false)
xlabel!("original")
ylabel!("recovered")
```
As the plot produced by this code shows, `predict` seems to use the slope parameter β when generating data, but for some reason not the noise parameter σ. Have I done something wrong or missed something here? And if not, is this the intended behaviour of `predict`? Thanks for your help!