Mamba Chains declaration for matrix of parameters

This might be a stupid question, but I just want to make sure I am doing it correctly. I have read the documentation but am still unsure.

I am trying to use Mamba’s NUTSVariate to implement a Bayesian neural network, so my parameters are the weight matrices and bias vectors.

Suppose I have a weight matrix \alpha\in\mathbb{R}^{6\times5} and a bias vector \beta\in\mathbb{R}^5; my code looks roughly like

α = rand(Normal(), 6, 5)
β = rand(Normal(), 1, 5)
Θ = [vec(α), vec(β)]
function logfgrad(Θ::DenseVector)
    α = reshape(Θ[1], J, K)   # J = 6, K = 5
    β = reshape(Θ[2], 1, K)
    loglik = ...              # log-likelihood evaluated at (α, β)
    Δα = ...                  # 6-by-5 array of gradients w.r.t. α[i,j]
    Δβ = ...                  # 1-by-5 array of gradients w.r.t. β[j]
    grad = [Δα, Δβ]
    return loglik, grad
end

n_samp = 10000
burnin = 5000
sim = Chains(iters = n_samp, params=?, start=(burnin+1), names=?)
samp = NUTSVariate(Θ, logfgrad)
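For comparison, the stand-alone sampler pattern in the Mamba docs looks roughly like the sketch below. It is only a sketch: a toy standard-normal target stands in for my actual log-likelihood so the example is self-contained, and `logfgrad_demo` and `pnames` are my own placeholder names, not anything from Mamba.

```julia
using Mamba

# Toy target: d independent standard normals, with the gradient coded by hand.
function logfgrad_demo(x::DenseVector)
    logf = -0.5 * sum(abs2, x)
    grad = -x
    return logf, grad
end

d = 35                        # e.g. 30 weights + 5 biases, flattened
n_samp, burnin = 10000, 5000
pnames = ["Θ[$i]" for i in 1:d]

sim = Chains(n_samp, d, names = pnames)   # iters and params are positional
theta = NUTSVariate(randn(d), logfgrad_demo)
for i in 1:n_samp
    sample!(theta, adapt = (i <= burnin)) # adapt during burn-in only
    sim[i, :, 1] = theta                  # store the current draw
end
```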

My questions are:

Are the input vector \Theta and output vector grad specified correctly? Currently they are in the form \Theta = [[\alpha_{1,1},\alpha_{2,1},\dots,\alpha_{6,5}],[\beta_1,\dots,\beta_5]] and \textbf{grad}=[[\Delta\alpha_{1,1},\Delta\alpha_{2,1},\dots,\Delta\alpha_{6,5}],[\Delta\beta_1,\dots,\Delta\beta_5]]. Should I instead flatten them into \Theta = [\alpha_{1,1},\alpha_{2,1},\dots,\alpha_{6,5},\beta_1,\dots,\beta_5], etc.?
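To make the flattened alternative concrete, this is what I mean (indices assume J = 6, K = 5; the deterministic fill values are just for illustration):

```julia
# Flattened form: stack column-major vec(α) and vec(β) into one 35-vector,
# then recover the original shapes with reshape. The round-trip is exact.
α = collect(reshape(1.0:30.0, 6, 5))
β = collect(reshape(31.0:35.0, 1, 5))
Θflat = [vec(α); vec(β)]               # vcat into a single flat vector
α2 = reshape(Θflat[1:30], 6, 5)
β2 = reshape(Θflat[31:35], 1, 5)
@assert α2 == α && β2 == β
```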

I think the field params is the number of parameters, right? In this case it should be length(Θ)?
For names, do I need to give a name for every single parameter, like ["α[1,1]","α[2,1]",...,"β[5]"]?

Thank you for all the help!

This doesn’t answer your question directly, but I recommend Turing or DynamicHMC, as they are more actively maintained and developed. Turing in particular is geared towards machine learning. You might find some useful information in this thread.

Hi Christopher,

Thank you for the recommendations. DynamicHMC seems interesting; however, is there any resource that provides examples with a self-coded gradient function? I would like to avoid automatic differentiation.

@Tamas_Papp

Also see https://github.com/TuringLang/AdvancedHMC.jl.

It is very simple, see

https://tamaspapp.eu/LogDensityProblems.jl/dev/#Manually-calculated-derivatives-1

for a template. This, however, applies to log density functions \mathbb{R}^n \to \mathbb{R}; if your domain is constrained, you have to account for transformations yourself.
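A minimal sketch of that template, assuming the current LogDensityProblems interface; the standard-normal target and the struct name are placeholders for your actual model:

```julia
using LogDensityProblems

# Placeholder target: dim independent standard normals, gradient coded by hand.
struct StdNormalTarget
    dim::Int
end

LogDensityProblems.dimension(ℓ::StdNormalTarget) = ℓ.dim

# Declare that the problem supplies first-order derivatives itself (no AD).
LogDensityProblems.capabilities(::Type{StdNormalTarget}) =
    LogDensityProblems.LogDensityOrder{1}()

LogDensityProblems.logdensity(ℓ::StdNormalTarget, x) = -0.5 * sum(abs2, x)

# Return the log density and its hand-coded gradient as a tuple.
LogDensityProblems.logdensity_and_gradient(ℓ::StdNormalTarget, x) =
    (-0.5 * sum(abs2, x), -x)
```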

Feel free to ping me if you need help with this.

Thanks! This is very useful.