Variational Inference for Multinomial output

Hi there,

I want to use variational inference (ADVI) in Turing for a model with multinomial output. Below is the simulated data: 400 two-dimensional observations in 4 groups of 100, centred at (1, 1), (-1, 1), (-1, -1) and (1, -1). The goal is to identify which group each observation belongs to.

Code for data simulation:

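```julia
using Random
using Statistics
using Distributions
using StatsFuns: logistic
using LinearAlgebra

using Flux
using Turing

N = 400

# Group centres
μ₁ = [1, 1]
μ₂ = [-1, 1]
μ₃ = [-1, -1]
μ₄ = [1, -1]

# 100 draws per group from an isotropic Gaussian (σ = 1),
# plus a third row holding the group label
C₁ = vcat(rand(MvNormal(μ₁, 1), 100), Transpose(ones(100)))
C₂ = vcat(rand(MvNormal(μ₂, 1), 100), Transpose(2 * ones(100)))
C₃ = vcat(rand(MvNormal(μ₃, 1), 100), Transpose(3 * ones(100)))
C₄ = vcat(rand(MvNormal(μ₄, 1), 100), Transpose(4 * ones(100)))

X = hcat(C₁, C₂, C₃, C₄)   # 3×400

# One-hot targets: Y[k, i] = 1 if observation i belongs to group k
Y = zeros(4, N)
Y[1, 1:100] .= 1
Y[2, 101:200] .= 1
Y[3, 201:300] .= 1
Y[4, 301:400] .= 1

# Keep only the two coordinates (drop the label row)
X = X[1:2, :]
```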
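For comparison, here is a shorter way to simulate the same kind of data, sketched with only the `Random` standard library. Because `MvNormal(μ, 1)` is an isotropic Gaussian with unit standard deviation, `μ .+ randn(2, 100)` draws from the same distribution, and the one-hot matrix can be built from an integer label vector, so no label row has to be added and then removed. The names `centers` and `labels` are introduced here for illustration only, and `Y` is stored as integers because multinomial counts are integers (the `Float64` matrix above holds the same values).

```julia
using Random

Random.seed!(1)  # fixed seed so the sketch is reproducible

N = 400
centers = ([1, 1], [-1, 1], [-1, -1], [1, -1])   # μ₁ to μ₄ from above

# 2×400 inputs: each group is its centre plus unit-variance Gaussian noise,
# the same distribution as rand(MvNormal(μ, 1), 100)
X = hcat((μ .+ randn(2, 100) for μ in centers)...)

# Integer group labels 1..4, 100 of each, in the same order as the columns of X
labels = repeat(1:4, inner = 100)

# 4×400 one-hot matrix: Y[k, i] == 1 exactly when observation i is in group k
Y = zeros(Int, 4, N)
for (i, k) in enumerate(labels)
    Y[k, i] = 1
end
```

Keeping `labels` around is also convenient later for checking how many observations are assigned to the correct group.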

Code for the model and inference:

```julia
# Split the flat 52-element parameter vector into the weights and biases
# of a 2 -> 4 -> 4 -> 4 network
function unpack(nn_params::AbstractVector)
    W₁ = reshape(nn_params[1:8], 4, 2)
    b₁ = nn_params[9:12]

    W₂ = reshape(nn_params[13:28], 4, 4)
    b₂ = nn_params[29:32]

    Wₒ = reshape(nn_params[33:48], 4, 4)
    bₒ = nn_params[49:52]
    return W₁, b₁, W₂, b₂, Wₒ, bₒ
end
```
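The hard-coded ranges in `unpack` follow from the layer sizes: a dense layer mapping `nin` inputs to `nout` outputs has `nout * nin` weights and `nout` biases, so the 2 -> 4 -> 4 -> 4 network has (8 + 4) + (16 + 4) + (16 + 4) = 52 parameters. The sketch below derives the same ranges from a list of layer shapes; `param_ranges` and `unpack_generic` are hypothetical helpers written for illustration, not part of Flux or Turing.

```julia
# (out, in) shape of each dense layer: 2 inputs -> 4 -> 4 -> 4 outputs
layers = [(4, 2), (4, 4), (4, 4)]

# Hypothetical helper: index ranges of W and b for each layer in a flat vector,
# laid out as W₁, b₁, W₂, b₂, ... (the same order as unpack above)
function param_ranges(layers)
    ranges = Tuple{UnitRange{Int},UnitRange{Int}}[]
    offset = 0
    for (nout, nin) in layers
        w = (offset + 1):(offset + nout * nin)
        offset += nout * nin
        b = (offset + 1):(offset + nout)
        offset += nout
        push!(ranges, (w, b))
    end
    return ranges, offset   # offset is now the total parameter count
end

# Hypothetical helper: unpack a flat vector into (W, b) pairs, one per layer
function unpack_generic(θ::AbstractVector, layers)
    ranges, total = param_ranges(layers)
    length(θ) == total || throw(ArgumentError("expected $total parameters, got $(length(θ))"))
    return [(reshape(θ[w], nout, nin), θ[b]) for ((w, b), (nout, nin)) in zip(ranges, layers)]
end

ranges, total = param_ranges(layers)
```

Deriving the ranges this way means a change to the hidden-layer width only requires editing `layers`, and the prior dimension (`52` in the model) can be taken from `total`.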

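```julia
# Forward pass: two sigmoid hidden layers and a softmax output,
# so each column of the result is a probability vector over the 4 groups
function nn_forward(xs, nn_params::AbstractVector)
    W₁, b₁, W₂, b₂, Wₒ, bₒ = unpack(nn_params)
    nn = Chain(Dense(W₁, b₁, σ),
               Dense(W₂, b₂, σ),
               Dense(Wₒ, bₒ),
               softmax)
    return nn(xs)
end;
```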
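One thing that may be worth trying, offered as a guess rather than a diagnosis: the `gep` in the error message takes a raw pointer (`Ptr{ForwardDiff.Dual{...}}`), which suggests that some routine in the gradient computation uses a pointer-level fast path that does not support ForwardDiff's `Dual` numbers. The sketch below swaps Flux's `Dense`, `σ` and `softmax` for hand-written Base Julia equivalents with the same parameter layout; it is generic over the element type, so it should also accept `Dual` inputs. The names `sigmoid`, `softmax_cols` and `plain_forward` are hypothetical, introduced only for this sketch.

```julia
# Sigmoid and column-wise softmax written with Base Julia only, so they
# work for any Real element type (Float64, BigFloat, ForwardDiff.Dual, ...)
sigmoid(z) = one(z) / (one(z) + exp(-z))

function softmax_cols(A::AbstractMatrix)
    M = maximum(A, dims = 1)   # subtract each column's max for numerical stability
    E = exp.(A .- M)
    return E ./ sum(E, dims = 1)
end

# Same 2 -> 4 -> 4 -> 4 architecture and parameter layout as unpack/nn_forward
function plain_forward(xs::AbstractMatrix, θ::AbstractVector)
    W₁ = reshape(θ[1:8], 4, 2);   b₁ = θ[9:12]
    W₂ = reshape(θ[13:28], 4, 4); b₂ = θ[29:32]
    Wₒ = reshape(θ[33:48], 4, 4); bₒ = θ[49:52]
    h₁ = sigmoid.(W₁ * xs .+ b₁)      # 4×N hidden activations
    h₂ = sigmoid.(W₂ * h₁ .+ b₂)      # 4×N hidden activations
    return softmax_cols(Wₒ * h₂ .+ bₒ) # 4×N class probabilities
end
```

To try it, replace `preds = nn_forward(xs, nn_params)` with `preds = plain_forward(xs, nn_params)` inside the model. If the same `MethodError` still appears, the Flux layers are probably not the cause and the problem lies elsewhere in the model.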

```julia
# Prior precision (regularization strength) and the corresponding
# prior standard deviation, sig = 1/√alpha ≈ 3.33
alpha = 0.09
sig = sqrt(1.0 / alpha)

# Specify the probabilistic model.
@model bayes_nn(xs, ts) = begin
    # Gaussian prior on the weight and bias vector (standard deviation sig).
    nn_params ~ MvNormal(zeros(52), sig .* ones(52))

    # Calculate predictions for the inputs given the weights
    # and biases in nn_params.
    preds = nn_forward(xs, nn_params)

    # Observe each one-hot target.
    for i in 1:size(ts, 2)
        ts[:, i] ~ Multinomial(1, preds[:, i])
    end
end;
```
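Each `ts[:, i] ~ Multinomial(1, preds[:, i])` statement adds the log-probability of one observation to the model's log joint. With `n = 1` the multinomial coefficient is 1, so a one-hot target for group k contributes just `log(preds[k, i])`; this is the same likelihood as `Categorical(preds[:, i])` evaluated at the integer label k. The sketch below computes that total in plain Julia; `onehot_loglik` is a hypothetical helper, and evaluating it with ordinary `Float64` parameters (for example `onehot_loglik(nn_forward(X, randn(52)), Y)`) is one way to check the forward pass and likelihood outside Turing before automatic differentiation is involved.

```julia
# Log-likelihood of one-hot targets Y (K×N) under column-wise class
# probabilities P (K×N), i.e. the sum over i of
# logpdf(Multinomial(1, P[:, i]), Y[:, i]) = log P[k, i] for the observed k
function onehot_loglik(P::AbstractMatrix, Y::AbstractMatrix)
    size(P) == size(Y) || throw(DimensionMismatch("P and Y must have the same size"))
    total = zero(eltype(P))
    for i in axes(Y, 2)
        k = findfirst(==(1), view(Y, :, i))   # index of the observed group
        k === nothing && throw(ArgumentError("column $i of Y is not one-hot"))
        total += log(P[k, i])
    end
    return total
end
```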

```julia
# ADVI with 10 Monte Carlo samples per gradient step and 10,000 optimization steps
advi = ADVI(10, 10_000)

ch_vi = vi(bayes_nn(X, Y), advi; optimizer = Flux.ADAM())
```
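Once `vi` returns a fitted approximation, each observation could be assigned to a group by averaging the network's predicted probabilities over parameter draws and taking the most probable group. This sketch assumes the fitted object supports `rand(q, S)` returning a 52×S matrix of draws, as multivariate distributions in Distributions.jl do; `predict_groups` is a hypothetical helper, and the forward function is passed in so it works with `nn_forward` or any drop-in replacement.

```julia
# Hypothetical helper: θs is a P×S matrix whose columns are parameter draws
# (e.g. rand(q, S) for a fitted variational posterior q). Averages the class
# probabilities over draws and picks the most probable group per observation.
function predict_groups(forward, xs::AbstractMatrix, θs::AbstractMatrix)
    S = size(θs, 2)
    P = sum(forward(xs, view(θs, :, s)) for s in 1:S) ./ S   # K×N mean probabilities
    groups = [argmax(view(P, :, i)) for i in axes(P, 2)]      # predicted group per column
    return groups, P
end
```

Usage would then look like `groups, P = predict_groups(nn_forward, X, rand(ch_vi, 1000))`, after which `mean(groups .== repeat(1:4, inner = 100))` gives the fraction of observations assigned to their true group.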

Here is the error I got:

```
MethodError: no method matching gep(::Ptr{ForwardDiff.Dual{ForwardDiff.Tag{Turing.Variational.var"#f#7"{ELBO,ADVI{Turing.Core.ForwardDiffAD{40}},Turing.Variational.MeanField{Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{Float64,1}},Turing.Model{Tuple{:nn_params},Tuple{:xs,:ts},var"##inner_function#599#55",NamedTuple{(:xs, :ts),Tuple{Array{Float64,2},Array{Float64,2}}},NamedTuple{(:xs, :ts),Tuple{Symbol,Symbol}}},Tuple{Int64}},Float64},Float64,10}}, ::Int64)
```

Would you please let me know how to make it work?

Thanks,
Chuan