Help with understanding Turing's predict function

Hi everybody,

I played around a bit with the predict() function in Turing and encountered some behaviour I don’t quite understand. When I use it to generate new output from a fitted chain, it seems to take some of the variables in the original chain into account but not others.

Here is some example code that reproduces the problem. In it, I simulate some data for a simple linear regression and fit that data using Turing. I then use the predict function to generate new data from the fitted chain, and fit the predicted data once more using the same model.

using Turing, StatsPlots

@model function linear_reg(x, y)
    β ~ Normal(0, 1)
    σ ~ Exponential(0.1)

    for i ∈ eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end;

pars = Array{Float64}(undef, 20, 2, 2)

for i = 1:20
    # simulate some data and fit a model to it
    f(x) = (i-20)/10 * x + i/100 * randn();
    Δ = 0.1; xs_train = 0:Δ:100; ys_train = f.(xs_train);
    m_train = linear_reg(xs_train, ys_train);
    chain_lin_reg = sample(m_train, NUTS(), 1000);

    # make predictions based on the fitted model
    m_test = linear_reg(xs_train, Vector{Union{Missing, Float64}}(undef, length(ys_train)));
    predictions = predict(m_test, chain_lin_reg)
    ys_predictions = vec(mean(Array(group(predictions, :y)); dims = 1))
    
    # fit a model to the predictions
    m_recover = linear_reg(xs_train, ys_predictions);
    chain_lin_reg_recover = sample(m_recover, NUTS(), 1000);

    # store the parameters
    pars[i, 1, :] = mean(Array(chain_lin_reg), dims = 1)
    pars[i, 2, :] = mean(Array(chain_lin_reg_recover), dims = 1)
end

p1 = scatter(pars[:,1,1], pars[:,2,1], legend=false, aspect_ratio=:equal);
title!("β")
p2 = scatter(pars[:,1,2], pars[:,2,2], legend=false, aspect_ratio=:equal);
title!("σ")
plot(p1, p2, layout = (1,2), legend=false)
xlabel!("original")
ylabel!("recovered")

[Plot: original vs. recovered posterior means; left panel β, right panel σ]

As the plot produced by this code shows, predict seems to use the slope parameter β of the regression model when generating data, but apparently not the noise parameter σ. Have I done something wrong or missed something here? If not, is this the intended behaviour of predict? Thanks for your help!

Do you want something like this?

I would’ve expected sigma to also be applied to the output of predict. Here I’m doing it manually to show what I expected the result to look like:

using Turing, StatsPlots

@model function linear_reg(x, y)
    β ~ Normal(0, 1)
    σ ~ Exponential(0.1)

    for i ∈ eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end;

pars = Array{Float64}(undef, 10, 2, 2)  # one row per simulated dataset (i = 1:10); 20 rows would leave 10 uninitialized rows in the plot

for i = 1:10
    # simulate some data and fit a model to it
    f(x) = (i-5)/10 * x + i/100 * randn();
    Δ = 0.1; xs_train = 0:Δ:100; ys_train = f.(xs_train);
    m_train = linear_reg(xs_train, ys_train);
    chain_lin_reg = sample(m_train, NUTS(), 1000);

    # make predictions based on the fitted model
    m_test = linear_reg(xs_train, Vector{Union{Missing, Float64}}(undef, length(ys_train)));
    predictions = predict(m_test, chain_lin_reg)
    ys_predictions = vec(mean(Array(group(predictions, :y)); dims = 1))

    # add noise to the predictions
    σ = only(mean(Array(group(chain_lin_reg, :σ)); dims = 1))  # posterior mean of σ as a scalar, so the next line returns a Vector
    ys_predictions = rand.(Normal.(ys_predictions, σ))

    # fit a model to the predictions
    m_recover = linear_reg(xs_train, ys_predictions);
    chain_lin_reg_recover = sample(m_recover, NUTS(), 1000);

    # store the parameters
    pars[i, 1, :] = mean(Array(chain_lin_reg), dims = 1)
    pars[i, 2, :] = mean(Array(chain_lin_reg_recover), dims = 1)
end

p1 = scatter(pars[:,1,1], pars[:,2,1], legend=false, aspect_ratio=:equal);
title!("β")
p2 = scatter(pars[:,1,2], pars[:,2,2], legend=false, aspect_ratio=:equal);
title!("σ")
fig = plot(p1, p2, layout = (1,2), legend=false)
xlabel!("original")
ylabel!("recovered")
savefig(fig,"prediciton_troubles.png")

[Plot: original vs. recovered posterior means with noise added manually; left panel β, right panel σ]

The issue here is that you’re using the mean prediction in the first example:)

The following is your offending line:

    ys_predictions = vec(mean(Array(group(predictions, :y)); dims = 1))

If you actually inspect predictions, i.e. the chain from the line before, you can see that the predictions do indeed have the correct std. Here is the result for i = 10:

...
Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Missing 

        y[1]   -0.0013    0.0966    0.0030   1046.1171    868.7508    1.0004       missing
        y[2]   -0.0980    0.1018    0.0034    907.8375    845.5306    1.0008       missing
        y[3]   -0.1987    0.0985    0.0033    903.8513    784.7725    1.0002       missing
        y[4]   -0.2984    0.0970    0.0030    995.8047    982.7992    0.9993       missing
        y[5]   -0.4002    0.0989    0.0031   1046.4576    942.0177    0.9999       missing
        y[6]   -0.5024    0.0980    0.0032    966.1871    767.5986    1.0021       missing
        y[7]   -0.6039    0.0972    0.0031   1014.4943   1024.3146    0.9999       missing
        y[8]   -0.7007    0.1003    0.0031   1046.7164    942.2477    0.9990       missing
...

As you can see, the std. is correct. It is about 0.1, which is the noise level i/100 used to simulate the data for i = 10.

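But then in the line above you just average this out :) Averaging N predictive draws for each y[i] shrinks the observation noise by roughly a factor of √N (here N = 1000). So ys_predictions is essentially just the mean regression line, and fitting a model to that line gives σ ≈ 0 in the posterior.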
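To see why averaging removes σ, here is a minimal stand-alone sketch that does not use Turing at all. It assumes hypothetical posterior draws of β and σ. These are made-up values chosen to resemble the i = 10 case (slope ≈ -1, noise ≈ 0.1). From them it builds a draws matrix laid out like Array(group(predictions, :y)), with one row per draw and one column per y[i]. It then compares the noise left around the regression line after averaging over draws with the noise in a single draw.

```julia
using Random, Statistics

Random.seed!(1)

# Hypothetical stand-in for `predict` output: these posterior draws are made up
# to mimic the i = 10 case (slope ≈ -1, noise ≈ 0.1); no Turing chain is involved.
xs = 0:0.1:100
S = 1000                              # number of posterior draws
βs = -1 .+ 0.0001 .* randn(S)         # hypothetical posterior draws of β
σs = 0.1 .+ 0.002 .* randn(S)         # hypothetical posterior draws of σ

# draws[s, j] is a predictive draw of y[j] = β_s * x_j + σ_s * ε
# (rows = draws, columns = y[1], y[2], ...)
draws = [βs[s] * x + σs[s] * randn() for s in 1:S, x in xs]

# Spread of a vector of ys around the line with the mean slope
resid_std(ys) = std(ys .- mean(βs) .* xs)

ys_mean   = vec(mean(draws; dims = 1))  # what the first example does
ys_single = draws[1, :]                 # a single posterior predictive draw

println("noise left after averaging: ", resid_std(ys_mean))   # roughly 0.1 / sqrt(S)
println("noise in a single draw:     ", resid_std(ys_single)) # roughly 0.1
```

By the same reasoning, using one posterior predictive draw (for example, one row of Array(group(predictions, :y))) instead of the column means should keep the noise that σ adds. This is a suggestion based on the sketch above, not something verified against Turing here.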
