I noticed something odd (or at least something I don’t understand) with the `Zygote.pullback` method.
Let’s consider a very simple test case:

```julia
using Zygote
losses, back = Zygote.pullback(x -> (sum(x), sum(x)), [1.0, 1.0])
```
Here we have a tuple-valued loss function, for which I would expect `back` to handle each entry separately, i.e. $B_1 = \frac{\partial y_1}{\partial \ell_1}\frac{\partial \ell_1}{\partial p}$ and $B_2 = \frac{\partial y_2}{\partial \ell_2}\frac{\partial \ell_2}{\partial p}$.
However, that isn’t what happens. Instead, `back` seems to treat the loss tuple as a vector and sum the results:
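Written out, assuming `back` follows the usual reverse-mode convention of returning a vector-Jacobian product, a seed $\bar y = (\bar y_1, \bar y_2)$ weights each entry’s gradient and the weighted gradients are added:

$$
\texttt{back}(\bar y) = \bar y_1 \frac{\partial \ell_1}{\partial p} + \bar y_2 \frac{\partial \ell_2}{\partial p}.
$$

Here $\ell_1 = \ell_2 = \texttt{sum}(p)$, so $\frac{\partial \ell_i}{\partial p} = (1, 1)$, and the seed $(1, 1)$ gives $(1, 1) + (1, 1) = (2, 2)$.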
```julia
grads = back((1.0, 1.0))
# output
(Fill(2.0, 2),)
```
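If separate gradients per entry are what is wanted, one option is to call `back` once per entry with a one-hot seed. This is a minimal sketch, assuming `back` computes a vector-Jacobian product (each entry’s gradient scaled by the matching seed component, then summed); the names `f`, `g1`, `g2` and `gboth` are just for illustration.

```julia
using Zygote

# Same tuple-valued function and input as in the question.
f(x) = (sum(x), sum(x))
x = [1.0, 1.0]

losses, back = Zygote.pullback(f, x)

# Seeding one tuple entry at a time (zero weight on the other)
# isolates the gradient of that entry alone.
g1, = back((1.0, 0.0))   # gradient of the first entry, sum(x)
g2, = back((0.0, 1.0))   # gradient of the second entry, sum(x)

# Seeding both entries with 1.0 gives the sum of the two gradients.
gboth, = back((1.0, 1.0))

println(collect(g1), " ", collect(g2), " ", collect(gboth))
```

A pullback can be called repeatedly, so the forward pass is only run once even though `back` is evaluated several times.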
And where does the `Fill` come from?
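The `Fill` in the output appears to be `Fill` from FillArrays.jl, a lazy array type that stores one value and a size, which Zygote uses for some constant-valued gradients (the exact Zygote internals are an assumption here). A minimal sketch, assuming FillArrays.jl is installed, showing that such a value behaves like an ordinary vector and can be materialized with `collect`:

```julia
using FillArrays

# A Fill stores a single value plus a size; every element reads as that
# value without the full array being allocated.
g = Fill(2.0, 2)

v = collect(g)   # materialize into an ordinary Vector{Float64}

# Adding two Fills of the same size is elementwise, like any array addition.
# (Assumed here to mirror how the two `sum` gradients get accumulated.)
s = Fill(1.0, 2) + Fill(1.0, 2)

println(g == [2.0, 2.0], " ", typeof(v), " ", s == g)
```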
Am I missing something? Or is this a bug?
Thanks!