# Variational Inference for Multinomial output

Hi there,

I want to use variational inference in Turing for a multinomial (multi-class) output. Below are the simulated data: 400 two-dimensional observations in 4 groups of 100, each group drawn from an isotropic unit-variance Gaussian centred at (±1, ±1). The goal is to identify which group each observation belongs to.
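Written out, the simulation code below implements the following data-generating process. Here $g_i$ (the group of observation $i$) and $t_i$ (its one-hot target) are notation introduced only for this summary:

$$
x_i \mid g_i = k \;\sim\; \mathcal{N}(\mu_k,\, I_2), \qquad
\mu_1 = \begin{pmatrix}1\\1\end{pmatrix},\;
\mu_2 = \begin{pmatrix}-1\\1\end{pmatrix},\;
\mu_3 = \begin{pmatrix}-1\\-1\end{pmatrix},\;
\mu_4 = \begin{pmatrix}1\\-1\end{pmatrix},
$$

$$
g_i = \lceil i / 100 \rceil, \qquad t_{ik} = \mathbb{1}[g_i = k], \qquad i = 1,\dots,400,\; k = 1,\dots,4.
$$

The model code then predicts class probabilities $\pi(x_i;\theta)$ with a softmax network and treats each target as $t_i \sim \mathrm{Multinomial}(1, \pi(x_i;\theta))$.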

Code for data simulation:

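```julia
using Random
using Statistics
using Distributions
using StatsFuns: logistic
using LinearAlgebra

using Flux
using Turing

Random.seed!(1)  # make the simulation reproducible

N = 400

# One cluster mean per quadrant
μ₁ = [1.0, 1.0]
μ₂ = [-1.0, 1.0]
μ₃ = [-1.0, -1.0]
μ₄ = [1.0, -1.0]

# 100 draws per group from N(μₖ, I); a third row temporarily holds the group label k
C₁ = vcat(rand(MvNormal(μ₁, I), 100), ones(1, 100))
C₂ = vcat(rand(MvNormal(μ₂, I), 100), 2 * ones(1, 100))
C₃ = vcat(rand(MvNormal(μ₃, I), 100), 3 * ones(1, 100))
C₄ = vcat(rand(MvNormal(μ₄, I), 100), 4 * ones(1, 100))

X = hcat(C₁, C₂, C₃, C₄)

# One-hot targets as integer counts (what Multinomial expects):
# column i has a 1 in the row of observation i's group
Y = zeros(Int, 4, N)
Y[1, 1:100] .= 1
Y[2, 101:200] .= 1
Y[3, 201:300] .= 1
Y[4, 301:400] .= 1

# Keep only the two coordinates; drop the label row
X = X[1:2, :]
```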
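The four `Y[k, range] .= 1` lines rely on the columns being ordered by group. A minimal sketch of the same one-hot construction driven by a label vector is below. The helper `onehot_targets` and the names `labels` and `Y_onehot` are hypothetical, not from the post. It also checks that `argmax` over each column recovers the label:

```julia
# Hypothetical helper: build a K × n one-hot matrix (Int entries) from labels in 1..K
function onehot_targets(labels::AbstractVector{<:Integer}, K::Integer)
    Y = zeros(Int, K, length(labels))
    for (i, k) in enumerate(labels)
        Y[k, i] = 1
    end
    return Y
end

# Labels in the same order as the columns of X: 100 × group 1, then 100 × group 2, ...
labels = repeat(1:4, inner = 100)
Y_onehot = onehot_targets(labels, 4)

# Each column contains exactly one 1, and argmax gives back the label
@assert all(sum(Y_onehot, dims = 1) .== 1)
@assert [argmax(c) for c in eachcol(Y_onehot)] == labels
```

Flux's `Flux.onehotbatch(labels, 1:4)` should give an equivalent (Bool-valued) encoding. The hand-written version keeps integer entries to match the `Multinomial` observations.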

Code for the model and inference:

```julia
# Split the flat 52-element parameter vector into the weights and biases
# of a 2 → 4 → 4 → 4 network
function unpack(nn_params::AbstractVector)
    W₁ = reshape(nn_params[1:8], 4, 2)
    b₁ = nn_params[9:12]

    W₂ = reshape(nn_params[13:28], 4, 4)
    b₂ = nn_params[29:32]

    Wₒ = reshape(nn_params[33:48], 4, 4)
    bₒ = nn_params[49:52]
    return W₁, b₁, W₂, b₂, Wₒ, bₒ
end

# Forward pass: two sigmoid hidden layers, then a softmax over the 4 classes
function nn_forward(xs, nn_params::AbstractVector)
    W₁, b₁, W₂, b₂, Wₒ, bₒ = unpack(nn_params)
    nn = Chain(Dense(W₁, b₁, σ),
               Dense(W₂, b₂, σ),
               Dense(Wₒ, bₒ),
               softmax)
    return nn(xs)
end

# Prior precision alpha (acts as the regularization strength) and the
# corresponding Gaussian prior standard deviation sig = 1/√alpha
alpha = 0.09
sig = sqrt(1.0 / alpha)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Isotropic Gaussian prior on all weights and biases (standard deviation sig)
    nn_params ~ MvNormal(zeros(52), sig^2 * I)

    # Class probabilities (4 × n matrix) for the inputs, given the
    # weights and biases in nn_params
    preds = nn_forward(xs, nn_params)

    # Observe each one-hot target column
    for i in 1:size(xs, 2)
        ts[:, i] ~ Multinomial(1, preds[:, i])
    end
end
```
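The hard-coded ranges in `unpack` follow from the 2 → 4 → 4 → 4 layer sizes: (4·2 + 4) + (4·4 + 4) + (4·4 + 4) = 12 + 20 + 20 = 52 parameters. Also, because every target column is one-hot, the `Multinomial(1, p)` pmf is just $p_k$ for the observed class $k$, so the log-likelihood is the sum over observations of the log-probability assigned to the observed class. Below is a dependency-free sketch (plain Julia, meant for a fresh session) that checks both points. `nparams`, `unpack_generic`, `forward`, `loglik`, `sigmoid_act` and `softmax_cols` are hypothetical helper names, and the hand-written sigmoid and softmax stand in for Flux's `σ` and `softmax`:

```julia
using Random

# Logistic sigmoid for the hidden layers (stands in for Flux's σ)
sigmoid_act(z) = 1 / (1 + exp(-z))

# Column-wise softmax, shifted by each column's max for numerical stability
function softmax_cols(Z::AbstractMatrix)
    E = exp.(Z .- maximum(Z, dims = 1))
    return E ./ sum(E, dims = 1)
end

# Number of weights + biases for fully connected layers with the given widths
nparams(sizes) = sum(sizes[l+1] * sizes[l] + sizes[l+1] for l in 1:length(sizes)-1)

# Slice a flat vector into (W, b) pairs in the same order as `unpack`:
# W₁ (column-major), b₁, W₂, b₂, ...
function unpack_generic(θ::Vector{<:Real}, sizes)
    layers = Tuple{Matrix{eltype(θ)}, Vector{eltype(θ)}}[]
    offset = 0
    for l in 1:length(sizes)-1
        nin, nout = sizes[l], sizes[l+1]
        W = reshape(θ[offset+1:offset+nout*nin], nout, nin)
        offset += nout * nin
        b = θ[offset+1:offset+nout]
        offset += nout
        push!(layers, (W, b))
    end
    return layers
end

# Forward pass: sigmoid on hidden layers, softmax on the output layer
function forward(xs::AbstractMatrix, θ::Vector{<:Real}, sizes)
    layers = unpack_generic(θ, sizes)
    h = xs
    for (l, (W, b)) in enumerate(layers)
        z = W * h .+ b
        h = l < length(layers) ? sigmoid_act.(z) : softmax_cols(z)
    end
    return h
end

# Log-likelihood of one-hot targets: log of the Multinomial(1, p) pmf is Σₖ tₖ log pₖ
# (assumes all probabilities are > 0; otherwise 0 * log(0) gives NaN)
loglik(P::AbstractMatrix, targets::AbstractMatrix) = sum(targets .* log.(P))

sizes = [2, 4, 4, 4]
Random.seed!(1)
xs = randn(2, 10)            # 10 toy inputs
θ = randn(nparams(sizes))    # random parameter vector
P = forward(xs, θ, sizes)    # 4 × 10 class probabilities
cls = rand(1:4, 10)          # toy class labels
targets = zeros(Int, 4, 10)
for (i, k) in enumerate(cls)
    targets[k, i] = 1
end

@assert nparams(sizes) == 52
@assert all(sum(P, dims = 1) .≈ 1)
@assert loglik(P, targets) ≈ sum(log(P[cls[i], i]) for i in 1:10)
```

Since the likelihood only reads one entry per column, observing an integer class label with `Categorical(preds[:, i])` should be an equivalent, slightly cheaper observation model in the Turing code. That is a design alternative, not what the post uses.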