Logistic Regression Mixture Model

I’m trying to implement robust logistic regression from chapter 10.5 of PML. The model is p(y|\mathbf{x}) = \pi \mathrm{Ber}(y|0.5) + (1-\pi) \mathrm{Ber}(y|\sigma(\mathbf{w}^\top\mathbf{x})), where \pi is the mixture weight, drawn from a uniform distribution, and the regression weights are drawn from a normal distribution with mean 0 and standard deviation 10.
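Both mixture components are Bernoulli distributions over the same y, so the mixture collapses to a single Bernoulli. Its success probability is exactly what the line p = pie .* 0.5 .+ (1 .- pie) .* theta in the code below computes:

```latex
\begin{aligned}
p(y=1\mid\mathbf{x}) &= 0.5\,\pi + (1-\pi)\,\sigma(\mathbf{w}^\top\mathbf{x}) \equiv p,\\
p(y=0\mid\mathbf{x}) &= 0.5\,\pi + (1-\pi)\,\bigl(1-\sigma(\mathbf{w}^\top\mathbf{x})\bigr) = 1 - p,\\
\text{hence}\quad p(y\mid\mathbf{x}) &= \mathrm{Ber}(y\mid p).
\end{aligned}
```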

The author has a PyMC3 model here, but when I try to sample my version with any sampler other than MH(), it blows up with numerical errors. Does this have to do with the logistic function?
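For reference, the logistic term itself can be evaluated safely. Below is a minimal sketch, in Base Julia only, of the per-observation mixture log-likelihood computed in log space. The function names (log1pexp_stable, logaddexp_stable, robust_loglik) are hypothetical; StatsFuns provides log1pexp and logaddexp for the same purpose. Working in log space keeps the result finite even when σ(η) underflows to 0 in Float64 and \pi is also close to 0.

```julia
# Numerically stable log-likelihood for the robust logistic mixture
#   p(y | x) = pie * Ber(y | 0.5) + (1 - pie) * Ber(y | σ(α + β x)),
# using Base Julia only so it runs standalone.

# log(1 + exp(x)) without overflow for large x
log1pexp_stable(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# log(exp(a) + exp(b)) without overflow/underflow
function logaddexp_stable(a::Real, b::Real)
    m = max(a, b)
    isinf(m) && return m  # both terms -Inf (or one is +Inf)
    return m + log1p(exp(-abs(a - b)))
end

# log p(y | x) for one observation; y must be 0 or 1.
function robust_loglik(y::Real, x::Real, α::Real, β::Real, pie::Real)
    (y == 0 || y == 1) || return -Inf  # outside the Bernoulli support
    η = α + β * x
    # log σ(η) = -log1pexp(-η) and log(1 - σ(η)) = -log1pexp(η)
    logθ = y == 1 ? -log1pexp_stable(-η) : -log1pexp_stable(η)
    # log(pie * 0.5 + (1 - pie) * θ_y), combined in log space
    return logaddexp_stable(log(pie) + log(0.5), log1p(-pie) + logθ)
end
```

Summing over the data is a broadcast, e.g. sum(robust_loglik.(y, x, α, β, pie)) with pie a vector. One option (an untested assumption about how to wire it in) is to add that sum with Turing.@addlogprob! instead of y .~ Bernoulli.(p).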

using Turing, MCMCChains, StatsPlots, Random, StatsFuns
using RDatasets

iris = dataset("datasets", "iris")
select!(iris, :SepalLength, :Species)
filter!(row -> row[:Species] in ["setosa", "versicolor"], iris)
iris[!, :target] = iris[!, :Species] .!= "setosa"
select!(iris, :SepalLength, :target);

log_x = iris[!, :SepalLength]
log_y = iris[!, :target]
log_xoutliers = [4.2, 4.5, 4.0, 4.3, 4.2, 4.4]
log_youtliers = ones(length(log_xoutliers))

log_x_train = [log_x; log_xoutliers]
log_y_train = [log_y; log_youtliers];

@model function robust_logreg(x, y)
    N = length(x)
    alpha ~ Normal(0, 10)
    beta ~ Normal(0, 10)
    
    mu  = alpha .+ beta .* x
    theta = logistic.(mu)
    
    pie ~ filldist(Beta(1, 1), N) # probability of contamination of outliers
    p = pie .* 0.5 .+ (1 .- pie) .* theta
    
    y .~ Bernoulli.(p)
end

roblog_model = robust_logreg(log_x_train, log_x_train)
roblog_chain = sample(roblog_model, HMC(0.05, 10), 1000)

I think you have a typo in this line:

roblog_model = robust_logreg(log_x_train, log_x_train)

You are passing in log_x_train twice, which means your observed data are actually your covariates. Sepal lengths such as 5.1 are not in {0, 1}, so evaluating the log density of the Bernoulli at them gives -Inf. If I change that line to robust_logreg(log_x_train, log_y_train), it runs without problems for me.
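To make the failure mode concrete, here is a minimal Base-Julia sketch that mimics the support check of a Bernoulli log-pmf, plus a guard you could run on the data before building the model. Both function names are hypothetical and not part of Distributions.jl or Turing.

```julia
# Log-pmf of Bernoulli(p) at y. Values outside {0, 1} have zero probability,
# so their log-probability is -Inf.
bernoulli_logpmf(y::Real, p::Real) = y == 1 ? log(p) : y == 0 ? log1p(-p) : -Inf

# Guard: throw early if any observation is not 0 or 1.
function check_binary_targets(y::AbstractVector)
    bad = findfirst(v -> !(v == 0 || v == 1), y)
    bad === nothing || throw(ArgumentError(
        "y[$bad] = $(y[bad]) is not 0 or 1; a Bernoulli likelihood would be -Inf"))
    return nothing
end

sepal_lengths = [5.1, 4.9, 4.7]  # what was passed as y by mistake
labels        = [0.0, 0.0, 1.0]  # what should have been passed

total_wrong = sum(bernoulli_logpmf.(sepal_lengths, 0.5))  # -Inf
total_right = sum(bernoulli_logpmf.(labels, 0.5))         # ≈ 3 * log(0.5)
```

A single -Inf term makes the whole log density -Inf, which gradient-based samplers such as HMC and NUTS cannot recover from.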

Also, unless you are using HMC for a specific reason, consider using NUTS(0.65) instead of HMC(0.05, 10) as your inference algorithm; as far as I know, PyMC3 uses NUTS by default.

Oh no! Thanks for spotting that. I’m glad it was just a simple mistake. Now that it’s working, I’ll use NUTS for sampling.
