Predict values using multinomial logistic regression?

Dear Julia community,

I have been trying to fit a multinomial logistic regression and then obtain predictions for a new set of covariate values.

I managed to fit the model, but I cannot get predictions for covariate values other than those used in the fit.

With a model fitted by the GLM package, I would do:

# Packages:
begin 
    using GLM
    using Plots
    using DataFrames
end

# Data: 
begin 
    Y                   = Float64.(rand(Binomial(), 100))
    X1                  = rand(100)
    data                = DataFrame(Y = Y, X1 = X1)
    new_data            = DataFrame(X1 = rand(100))
end

# Regression:
begin 
    logistic_model      = GLM.glm(@formula(Y ~ X1), data, Bernoulli(), LogitLink())
    predicted_values    = GLM.predict(logistic_model, new_data)
    # Plot predictions against the covariate values they were computed for
    Plots.scatter(new_data.X1, predicted_values)
end

Now, if I fit a multinomial logistic regression instead, how can I get results similar to those returned by `GLM.predict()`?

First, trying with Econometrics.jl:

# Package: 
begin 
    using Econometrics
end

# Data: 
begin 
    Y           = rand(1:5, 100)
    X1          = rand(1:5,100)
    data        = DataFrame(Y = Y, X1 = X1)
    new_data    = DataFrame(X1 = rand(1:5,100)) 
end

# Regression:
begin 
    multinomial_logistic        = Econometrics.fit(EconometricModel,
                                    @formula(Y ~ X1 ),
                                    data)
    Econometrics.predict(multinomial_logistic) # Returns fitted values only for the data the model was fitted on
    # Econometrics.predict(multinomial_logistic, new_data) # Method error 
end
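
For reference, this is the kind of computation I am after for new covariate values. It is only a minimal sketch in plain Julia, assuming the usual multinomial logit form with the first category as reference. The coefficient matrix `B` is made up for illustration and is not taken from Econometrics.jl, and `multinomial_probs` is a hypothetical helper, not a function from any package.

```julia
# Multinomial logit probabilities computed by hand.
# B is a HYPOTHETICAL coefficient matrix (not output from any package):
# rows = (intercept, X1), one column per non-reference outcome.
function multinomial_probs(X::AbstractMatrix, B::AbstractMatrix)
    eta = X * B                                # n × (K-1) linear predictors
    eta = hcat(zeros(size(eta, 1)), eta)       # reference category has η = 0
    eta = eta .- maximum(eta, dims = 2)        # subtract row maxima for numerical stability
    w = exp.(eta)
    return w ./ sum(w, dims = 2)               # each row sums to 1
end

# Made-up coefficients for a 5-category outcome (4 non-reference columns)
B = [0.2  -0.1   0.4   0.0;
     0.1   0.3  -0.2   0.05]

new_X1 = [1.0, 3.0, 5.0]                       # new covariate values
X = hcat(ones(length(new_X1)), new_X1)         # add the intercept column

probs = multinomial_probs(X, B)                # 3 × 5 matrix of probabilities
predicted_class = [argmax(row) for row in eachrow(probs)]
```

With real estimates, `B` would have to be arranged to match however the fitted model orders its coefficients, which I have not verified for Econometrics.jl.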

Trying with OrdinalMultinomialModels.jl, I run:

begin 
    using OrdinalMultinomialModels

    model = polr(@formula(Y ~ X1), data, LogitLink())

    # predict(model) # Method error

    #  OrdinalMultinomialModels.predict(model, new_data) # Method error too
    # predict function ?
end

My problem here is that I do not find a function similar to the GLM.predict() one.

When listing the methods of predict, I get entries from all of the packages mentioned above:

methods(predict)

# 13 methods for generic function "predict" from StatsAPI:
  [1] predict(mm::StatsModels.TableRegressionModel{T, S}, data; kwargs...) where {T<:OrdinalMultinomialModel, S<:(Matrix)}
     @ OrdinalMultinomialModels ~/.julia/packages/OrdinalMultinomialModels/axgiL/src/ordmnfit.jl:183
  [2] predict(m::StatsModels.TableRegressionModel, new_x::AbstractMatrix; kwargs...)
     @ StatsModels ~/.julia/packages/StatsModels/mPD8T/src/statsmodel.jl:137
  [3] predict(mm::StatsModels.TableRegressionModel, data; kwargs...)
     @ StatsModels ~/.julia/packages/StatsModels/mPD8T/src/statsmodel.jl:172
  [4] predict(a::StatsModels.TableRegressionModel, args...; kwargs...)
     @ StatsModels ~/.julia/packages/StatsModels/mPD8T/src/statsmodel.jl:28
  [5] predict(mm::LinearModel, newx::AbstractMatrix, interval::Symbol, level::Real)
     @ GLM deprecated.jl:103
  [6] predict(mm::LinearModel, newx::AbstractMatrix; interval, level)
     @ GLM ~/.julia/packages/GLM/vM20T/src/lm.jl:250
  [7] predict(mm::LinearModel, newx::AbstractMatrix, interval::Symbol)
     @ GLM deprecated.jl:103
  [8] predict(mm::GLM.LinPredModel)
     @ GLM ~/.julia/packages/GLM/vM20T/src/linpred.jl:265
  [9] predict(mm::GLM.AbstractGLM, newX::AbstractMatrix; offset, interval, level, interval_method)
     @ GLM ~/.julia/packages/GLM/vM20T/src/glmfit.jl:650
 [10] predict(obj::EconometricModel{<:Econometrics.NominalResponse})
     @ Econometrics ~/.julia/packages/Econometrics/yAppe/src/statsbase.jl:176
 [11] predict(obj::EconometricModel{<:Econometrics.OrdinalResponse})
     @ Econometrics ~/.julia/packages/Econometrics/yAppe/src/statsbase.jl:178
 [12] predict(obj::EconometricModel)
     @ Econometrics ~/.julia/packages/Econometrics/yAppe/src/statsbase.jl:175
 [13] predict(m::OrdinalMultinomialModel, newX::Matrix{T}; kind) where T<:Union{Float32, Float64}
     @ OrdinalMultinomialModels ~/.julia/packages/OrdinalMultinomialModels/axgiL/src/ordmnfit.jl:144

Going to the source of OrdinalMultinomialModels.jl, I find that the following function exists:

predict(m::OrdinalMultinomialModel, newX::Matrix{T}; kind::Symbol=:class) 

But the type of my model when running model = polr(@formula(Y ~ X1), data, LogitLink()) is:

StatsModels.TableRegressionModel{OrdinalMultinomialModel{Int64, Float64, LogitLink}, Matrix{Float64}}.

I am quite sure I am missing the function I want, given the number of methods that exist for predict() in both packages. I am sorry if my question is not relevant. Having an equivalent of GLM.predict() would be of great help.

Maybe I am missing something else?

Does anyone know if such a function exists for multinomial logistic regression? Any hint would be greatly appreciated.

Thank you.

It might be good to add docs for predict to the packages, but from the tests of OrdinalMultinomialModels.jl you can get:

julia> using OrdinalMultinomialModels, RDatasets

julia> housing = dataset("MASS", "housing");

julia> houseplr1 = polr(@formula(Sat ~ Infl + Type + Cont), housing,
                   LogitLink(), wts = housing[!, :Freq])
StatsModels.TableRegressionModel{OrdinalMultinomialModel{Int64, Float64, LogitLink}, Matrix{Float64}}

Sat ~ Infl + Type + Cont

Coefficients:
───────────────────────────────────────────────────────────────
                        Estimate  Std.Error   t value  Pr(>|t|)
───────────────────────────────────────────────────────────────
intercept Low|Medium   -0.496141  0.124541   -3.98376    0.0002
intercept Medium|High   0.690706  0.125212    5.51628    <1e-06
Infl: Medium            0.566392  0.104963    5.39611    <1e-05
Infl: High              1.28881   0.126705   10.1718     <1e-14
Type: Apartment        -0.572352  0.118747   -4.81991    <1e-05
Type: Atrium           -0.366182  0.156766   -2.33586    0.0226
Type: Terrace          -1.09101   0.151514   -7.20075    <1e-09
Cont: High              0.360284  0.0953574   3.77825    0.0003
───────────────────────────────────────────────────────────────

julia> predict(houseplr1, housing, kind=:probs)
72×3 DataFrame
 Row │ Low       Medium    High
     │ Float64?  Float64?  Float64?
─────┼──────────────────────────────
   1 │ 0.378448  0.287676  0.333876
   2 │ 0.378448  0.287676  0.333876
   3 │ 0.378448  0.287676  0.333876
(...)
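
As a sanity check on what `kind=:probs` returns, the first row can be reproduced by hand from the coefficient table. This is a sketch under two assumptions: that the model uses the cumulative logit form P(Y ≤ j) = logistic(θ_j − η), and that row 1 of the housing data has every covariate at its reference level (Infl = Low, Type = Tower, Cont = Low), so that η = 0. The helper names `logistic` and `ordinal_probs` are mine, not part of the package.

```julia
# Standard logistic function
logistic(z) = 1 / (1 + exp(-z))

# Cut points copied from the coefficient table above
theta = [-0.496141, 0.690706]          # Low|Medium, Medium|High

# Category probabilities from cumulative probabilities.
# ASSUMPTION: P(Y ≤ j) = logistic(theta[j] - eta)
function ordinal_probs(theta, eta)
    cum = [logistic.(theta .- eta); 1.0]   # P(Y ≤ j); the last category has cumulative probability 1
    return [cum[1]; diff(cum)]             # successive differences give P(Y = j)
end

# All covariates at reference levels, so the linear predictor is 0
p = ordinal_probs(theta, 0.0)
println(round.(p, digits = 6))             # order: Low, Medium, High
```

Under those assumptions the result agrees with the first row of the DataFrame above to about five decimal places. For the toy data in the question, the same call pattern would presumably be `predict(model, new_data, kind=:probs)`, though I have not run it on that data.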