Amortized Hierarchical Variational Model

Suppose I have a complex neural network that outputs four parameters mu, precision, alpha, and beta, which may be fixed for all of training, learned during training, or vary per batch, e.g.:

```julia
X, Y = next_batch(data)
(mu, precision, alpha, beta) = ComplexNN(X)
```

Then I have Parent := NormalInverseGamma(mu, precision, alpha, beta), Q := Normal(mean, variance), and P := Normal(mean_prior, variance_prior), where (mean, variance) ~ Parent and target ~ Q. The target samples will be used by some other task, which has its own loss. How could I perform the training with L = LogLikelihood(Q, Y) - DKL(Q, P) + other_task_loss?
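For reference, here is a minimal sketch (not from the thread) of how the hierarchy above might be written as a Turing model. It assumes the NormalInverseGamma parent is decomposed into its standard InverseGamma/Normal factorization, with mu, precision, alpha, beta coming from ComplexNN as above; the DKL and other_task_loss terms are left out:

```julia
using Turing, Distributions

# Hedged sketch: NormalInverseGamma(mu, precision, alpha, beta) written out
# via its standard factorization:
#   variance ~ InverseGamma(alpha, beta)
#   mean | variance ~ Normal(mu, sqrt(variance / precision))
@model function hierarchy(Y, mu, precision, alpha, beta)
    variance ~ InverseGamma(alpha, beta)
    mean ~ Normal(mu, sqrt(variance / precision))
    # Q := Normal(mean, variance); note Distributions.Normal takes a
    # standard deviation, hence sqrt(variance)
    for i in eachindex(Y)
        Y[i] ~ Normal(mean, sqrt(variance))
    end
end
```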

You can increment the log joint probability manually in a Turing model using acclogp!(_varinfo, lp), where lp is your new “loss”. Technically, what gets added should be the negative of the “loss”: in an optimization context we maximize the log probability, so its negative is the “loss” being minimized.
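As a hedged sketch of what that could look like (the exact API depends on your Turing version; `target_loss` is a hypothetical stand-in for the question’s other_task_loss):

```julia
using Turing

# Hypothetical downstream loss, standing in for other_task_loss
target_loss(m) = (m - 1)^2

@model function with_extra_term(Y)
    mean ~ Normal(0, 1)
    for i in eachindex(Y)
        Y[i] ~ Normal(mean, 1)
    end
    # acclogp! adds `lp` to the accumulated log joint; since inference
    # maximizes log probability, a loss to be *minimized* enters negated.
    acclogp!(_varinfo, -target_loss(mean))
end
```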


Thank you for your reply. I have a few questions about it. So lp = -(LogLikelihood(Q, Y) - DKL(Q, P) + other_task_loss), right? Also, what would _varinfo be? One more thing: will ADVI update the parameters of my (Flux) NN out of the box, or do I need to register them somehow?

Thanks in advance.

I think you may not need the negative sign here. I was mostly talking about the other_task_loss term, since I don’t know what it refers to in this context. The term “loss” usually implies we are interested in its low values, whereas we are interested in high values of the log likelihood, for example. It could just be a terminology mix-up and you didn’t mean to imply a “loss”.

Turing treats all parameters equally, whether they are NN parameters or not. There is a tutorial in the docs on Bayesian NNs. I don’t know if we have an ADVI one; @torfjelde might know.

_varinfo is a reserved name, available only inside the @model body, that lets you access the internal data structure we use to track random variables and log probabilities.


Thank you for your answers!

> I was mostly talking about other_task_loss term which I don’t know what it refers to in this context

It depends on the final task, but it can be cross entropy, for example.
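For instance, a hedged sketch of binary cross entropy as such a loss (the function name and arguments are illustrative, not from the thread):

```julia
using Statistics

# Binary cross entropy between predicted probabilities p̂ and labels y
cross_entropy(p̂, y) = -mean(@. y * log(p̂) + (1 - y) * log(1 - p̂))

cross_entropy([0.9, 0.1], [1, 0])  # ≈ 0.105, i.e. -log(0.9)
```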

I have a question regarding the ADVI part: do you have any idea how efficient this is in practice (with respect to training time) compared to amortized VI?