Suppose I have a complex neural network producing four parameters
mu, precision, alpha, beta, which may be fixed for all of training, learned during training, or vary per batch, e.g.:
X, Y = next_batch(data)
(mu, precision, alpha, beta) = ComplexNN(X)
Then I have
Parent := NormalInverseGamma(mu, precision, alpha, beta),
Q := Normal(mean, variance), and
P := Normal(mean_prior, variance_prior), where
(mean, variance) ~ Parent and
target ~ Q. The
target samples will be used by some other task, which has its own loss. How could I perform the training with
L = LogLikelihood(Q, Y) - DKL(Q, P) + other_task_loss?
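To make the objective concrete, here is a minimal NumPy sketch of the pieces above. All function names are illustrative, not from any library; it assumes the common normal-inverse-gamma parameterization (variance ~ InverseGamma(alpha, beta), mean | variance ~ Normal(mu, variance / precision)) and a closed-form KL between two univariate Gaussians, and it treats other_task_loss as a scalar supplied by the other task:

```python
import numpy as np

def sample_normal_inverse_gamma(mu, precision, alpha, beta, rng):
    # (mean, variance) ~ NIG(mu, precision, alpha, beta), assuming:
    #   variance ~ InverseGamma(alpha, beta)
    #   mean | variance ~ Normal(mu, variance / precision)
    variance = 1.0 / rng.gamma(alpha, 1.0 / beta)  # inverse-gamma via reciprocal gamma
    mean = rng.normal(mu, np.sqrt(variance / precision))
    return mean, variance

def gaussian_log_likelihood(y, mean, variance):
    # sum_i log N(y_i | mean, variance)
    return np.sum(-0.5 * np.log(2 * np.pi * variance)
                  - 0.5 * (y - mean) ** 2 / variance)

def kl_gaussians(mean_q, var_q, mean_p, var_p):
    # closed-form KL( N(mean_q, var_q) || N(mean_p, var_p) )
    return 0.5 * (np.log(var_p / var_q)
                  + (var_q + (mean_q - mean_p) ** 2) / var_p
                  - 1.0)

def objective(y, mean, variance, mean_prior, variance_prior, other_task_loss):
    # L = LogLikelihood(Q, Y) - DKL(Q, P) + other_task_loss, as in the question
    return (gaussian_log_likelihood(y, mean, variance)
            - kl_gaussians(mean, variance, mean_prior, variance_prior)
            + other_task_loss)
```

In a real setup, (mean, variance) would be drawn from the parent distribution whose parameters come out of the network, and L would be differentiated through those parameters.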
You can increment the log joint probability manually in a Turing model using
acclogp!(_varinfo, lp), where
lp is your new “loss”. Technically, you would add the negative of the “loss”, because in an optimization context we maximize the log probability, so its negative is the “loss” being minimized.
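The sign convention can be checked numerically. This is a generic sketch (plain Python, not Turing code) showing that running gradient descent on loss = -objective is the same as maximizing the objective; the toy objective and step size are arbitrary:

```python
def objective(theta):
    # toy concave objective with its maximum at theta = 3
    return -(theta - 3.0) ** 2

def loss(theta):
    # the quantity a minimizer would work on: negative of the objective
    return -objective(theta)

def numeric_grad(f, theta, eps=1e-6):
    # central-difference approximation of df/dtheta
    return (f(theta + eps) - f(theta - eps)) / (2 * eps)

theta = 0.0
for _ in range(100):
    theta -= 0.1 * numeric_grad(loss, theta)  # gradient descent on the loss

# theta moves toward 3: minimizing the loss has maximized the objective
```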
Thank you for your reply. I have a few questions about it. So
lp = -(LogLikelihood(Q,Y) - DKL(Q,P) + other_task_loss), right? Also, what would
_varinfo be? One more thing: will ADVI update the parameters of my (Flux) NN out of the box, or do I need to register them somehow?
Thanks in advance.
I think you may not need the negative sign here. I was mostly referring to the
other_task_loss term, which I don’t know the meaning of in this context. But the term “loss” implies we are interested in its low values, whereas we are interested in high values of, for example, the log likelihood. It could just be a terminology mix-up and you didn’t mean to imply “loss”.
Turing treats all parameters equally, NN parameters or not; that’s irrelevant. There is a tutorial in the docs on Bayesian NNs. I don’t know if we have an ADVI one. @torfjelde might know.
_varinfo is a reserved name, only available inside the
@model body, that lets you access the internal data structure we use to track random variables and log probabilities.
Thank you for your answers!
I was mostly referring to the
other_task_loss term, which I don’t know the meaning of in this context
It depends on the final task, but it could be, for example, a cross-entropy loss.
I have a question regarding the ADVI part: do you have any idea how efficient this is in practice (with respect to training time) compared to amortized VI?