Hello,
I am struggling to compute the Deviance Information Criterion (DIC) in Julia, as Turing/MCMCDiagnosticTools do not have an implementation of DIC comparable to R's dic.samples(mod2, n.iter=1e3).
Background: I am enrolled in a Coursera course on Bayesian statistics where these models are built in R. I am new to Julia, and since Turing.jl has a very expressive syntax (chef’s kiss), I decided to do the coursework in Julia. Assignments and quizzes often require me to compute DIC, so I tried implementing it on my own, but my results are not close to what I get from R’s dic.samples(mod, n.iter=1e3) function.
Here is my implementation for a simpler case. Any help figuring out what I am missing would be greatly appreciated:
using RDatasets, DataFrames
using Turing
using CategoricalArrays
using Distributions
using Statistics
# Load data
df = RDatasets.dataset("datasets", "warpbreaks")
println(first(df, 5))
# Create categorical variable and log-transform
df.cat = levelcode.(categorical(string.(df.Tension, "-", df.Wool)))
df.logBreaks = log.(df.Breaks)
# ANOVA model
@model function anova_t(X, y)
    num_cat = length(unique(X.cat))
    num_obs = length(y)
    μ ~ filldist(Normal(0.0, 1e3), num_cat)
    σ² ~ filldist(InverseGamma(0.5, 0.5), num_cat)
    for i in 1:num_obs
        y[i] ~ Normal(μ[X.cat[i]], sqrt(σ²[X.cat[i]]))
    end
end
# Fit model
model_1 = anova_t(df[:, [:cat]], df.logBreaks)
chains = sample(model_1, NUTS(), 5000)
# Log-likelihood function
function log_likelihood(dat, params)
    μ = params[iter=1, var=[Symbol("μ[$i]") for i in 1:length(unique(dat.cat))]]
    σ = sqrt.(params[iter=1, var=[Symbol("σ²[$i]") for i in 1:length(unique(dat.cat))]])
    sum(logpdf.(Normal.(μ.data[dat.cat], σ[dat.cat]), dat.logBreaks))
end
# Deviance calculations
mean_deviance = -2 * mean([log_likelihood(df, chains.value[iter=[i], var=:, chain=1]) for i in 1:size(chains, 1)])
deviance_at_mean = -2 * log_likelihood(df, mean(chains.value[iter=:, var=:, chain=1], dims=1))
println("mean_deviance: $mean_deviance") # prints mean_deviance: 60.31756252213876
println("p_D: $(mean_deviance - deviance_at_mean)") # prints p_D: 4.72884678655096
println("DIC: $(2 * mean_deviance - deviance_at_mean)") # prints DIC: 65.04640930868972
describe(chains)
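For reference, the quantities the Julia code above is meant to compute are the classical DIC terms, with the plug-in estimate taken at the posterior mean of the sampled parameters (here μ and σ²). Here S is the number of posterior draws and θ^(s) is the s-th draw. Whether dic.samples uses this same penalty definition is an assumption I have not verified:

```latex
\begin{aligned}
D(\theta) &= -2 \log p(y \mid \theta) \\
\bar{D} &\approx \frac{1}{S} \sum_{s=1}^{S} D\big(\theta^{(s)}\big) \\
p_D &= \bar{D} - D(\bar{\theta}), \qquad \bar{\theta} = \frac{1}{S} \sum_{s=1}^{S} \theta^{(s)} \\
\mathrm{DIC} &= \bar{D} + p_D = 2\bar{D} - D(\bar{\theta})
\end{aligned}
```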
My R code, on the other hand, prints the following (which is correct as per the course):
Mean deviance: 60.26
penalty 14.68
Penalized deviance: 74.94
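To make the deviance arithmetic easier to check independently of Turing and of the chain indexing, here is a minimal self-contained sketch that works on plain matrices of draws. The names `normal_logpdf`, `deviance` and `dic_summary` are hypothetical helpers, not part of any package, and the draws are random stand-ins rather than real posterior samples. It also reports p_V = var(D)/2, an alternative estimate of the effective number of parameters; I am not claiming that either penalty is what dic.samples reports.

```julia
using Statistics  # standard library: mean, var
using Random      # standard library: reproducible toy draws

# Log-density of Normal(μ, σ) at x, written out so no packages are needed.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Deviance D(θ) = -2 log p(y | θ) for a one-way model with per-group
# means μ and variances σ2; group[i] is the group index of observation i.
function deviance(y, group, μ, σ2)
    ll = sum(normal_logpdf(y[i], μ[group[i]], sqrt(σ2[group[i]])) for i in eachindex(y))
    return -2 * ll
end

# μ_draws and σ2_draws are (n_draws × n_groups) matrices of draws.
function dic_summary(y, group, μ_draws, σ2_draws)
    n_draws = size(μ_draws, 1)
    D = [deviance(y, group, μ_draws[s, :], σ2_draws[s, :]) for s in 1:n_draws]
    D_bar = mean(D)                       # posterior mean deviance
    μ_hat = vec(mean(μ_draws, dims=1))    # plug-in: posterior mean of μ
    σ2_hat = vec(mean(σ2_draws, dims=1))  # plug-in: posterior mean of σ²
    D_hat = deviance(y, group, μ_hat, σ2_hat)
    p_D = D_bar - D_hat                   # classical penalty
    p_V = var(D) / 2                      # variance-based alternative
    return (mean_deviance = D_bar, deviance_at_mean = D_hat,
            p_D = p_D, DIC_pD = D_bar + p_D,
            p_V = p_V, DIC_pV = D_bar + p_V)
end

# Toy inputs (assumptions for illustration only, not the warpbreaks fit).
Random.seed!(1)
group = repeat(1:2, inner = 5)            # 10 observations in 2 groups
y = randn(10) .+ group                    # fake data
μ_draws = 1.5 .+ 0.3 .* randn(1000, 2)    # stand-in draws of μ
σ2_draws = exp.(0.2 .* randn(1000, 2))    # positive stand-in draws of σ²

res = dic_summary(y, group, μ_draws, σ2_draws)
println(res)
```

With a real fit, the two matrices would be filled from the μ[i] and σ²[i] columns of the chain. Because the plug-in step depends on which parameterization is averaged (σ², σ, or precision), p_D can change noticeably between otherwise equivalent implementations.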