 # Amortized Hierarchical Variational Model

Suppose I have a complex neural network that outputs four parameters `mu, precision, alpha, beta`, which may be fixed for all of training, learned during training, or vary per batch, for example:

```julia
X, Y = next_batch(data)                    # draw a minibatch
mu, precision, alpha, beta = ComplexNN(X)  # network outputs the four parameters
```

Then I have `Parent := NormalInverseGamma(mu,precision,alpha,beta)`, `Q := Normal(mean, variance)`, and `P := Normal(mean_prior, variance_prior)`, where `(mean, variance) ~ Parent` and `target ~ Q`. The `target` samples will be used by another task which has its own loss. How could I perform the training with `L = LogLikelihood(Q,Y) - DKL(Q,P) + other_task_loss`?
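For what it's worth, here is a minimal sketch (in Turing.jl) of the generative structure I have in mind. Since Turing samples scalar variables more easily than tuple-valued distributions, the `NormalInverseGamma` parent is written out via its standard factorization (`variance ~ InverseGamma(alpha, beta)`, then `mean | variance ~ Normal(mu, sqrt(variance / precision))`); the names are the ones from above:

```julia
using Turing

@model function amortized_model(Y, mu, precision, alpha, beta)
    # Parent := NormalInverseGamma(mu, precision, alpha, beta),
    # written out via its standard factorization:
    variance ~ InverseGamma(alpha, beta)
    mean ~ Normal(mu, sqrt(variance / precision))
    # Q := Normal(mean, variance); Y is scored under Q
    # (Normal takes a standard deviation, hence the sqrt)
    Y .~ Normal(mean, sqrt(variance))
end
```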

You can increment the log joint probability manually in a Turing model using `acclogp!(_varinfo, lp)`, where `lp` is your new “loss”. Technically, what you add is the negative of the “loss”: in an optimization context we maximize the log probability, so its negative is the “loss” being minimized.
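For example, a hedged sketch assuming the `acclogp!(_varinfo, lp)` API above (newer DynamicPPL releases expose the `@addlogprob!` macro for the same purpose):

```julia
using Turing, Distributions

@model function custom_objective(Y)
    variance ~ InverseGamma(2, 3)
    mean ~ Normal(0, sqrt(variance))
    # Instead of writing Y ~ Normal(...), add the likelihood term manually:
    lp = loglikelihood(Normal(mean, sqrt(variance)), Y)
    acclogp!(_varinfo, lp)
end
```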


Thank you for your reply. I have a few questions about it. So `lp = -(LogLikelihood(Q,Y) - DKL(Q,P) + other_task_loss)`, right? Also, what would `_varinfo` be? One more thing, will ADVI update the parameters of my (Flux) NN out-of-the-box, or do I need to register them somehow?

I think you may not need the negative sign here. I was mostly talking about the `other_task_loss` term, which I don’t know the meaning of in this context. The term “loss” implies that we are interested in its low values, whereas we want high values of the log likelihood, for example. It could just be a terminology mix-up and you didn’t actually mean a “loss”.
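Concretely, and assuming Distributions.jl’s `loglikelihood` and `kldivergence` (both implemented for `Normal`), the sign convention inside the `@model` body would look something like this, with the placeholder names from your question:

```julia
# acclogp! accumulates a quantity that is *maximized*, so reward-like terms
# (log likelihood) enter with +, and loss-like terms (KL divergence,
# other_task_loss) enter with -:
lp = loglikelihood(Q, Y) - kldivergence(Q, P) - other_task_loss
acclogp!(_varinfo, lp)
```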

Turing treats all parameters equally; whether they are NN parameters or not is irrelevant. There is a tutorial in the docs on Bayesian NNs. I don’t know if we have one for ADVI. @torfjelde might know.
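If it helps, here is a minimal ADVI sketch against a model like the one in your question, assuming Turing’s `vi`/`ADVI` interface; `next_batch`, `ComplexNN`, and `amortized_model` are the placeholders from above. Note that in this sketch the NN outputs enter as fixed arguments, so `vi` only fits the variational posterior over the model’s random variables, not the Flux weights:

```julia
using Turing
using Turing.Variational

X, Y = next_batch(data)                    # placeholders from the question
mu, precision, alpha, beta = ComplexNN(X)  # amortized parameters (fixed here)
model = amortized_model(Y, mu, precision, alpha, beta)

# ADVI(samples_per_step, max_iters); returns a variational posterior q
q = vi(model, ADVI(10, 1000))
```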

`_varinfo` is a reserved name, available only inside the `@model` body, that lets you access the internal data structure Turing uses to track random variables and log probabilities.

