Perhaps my issue is best illustrated with a toy example. I have a function `loss(x, y, z)` that returns a scalar value, the MSE. However, it also computes some intermediate values along the way, which I return as well so that I can plot them. So my loss function returns `A, B, C, MSE`, where `A`, `B` and `C` are intermediate values that I might want to plot over the course of gradient descent.
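For concreteness, a loss with this shape might look like the sketch below. The body is purely hypothetical (the real computation is not shown above); the only point is that it returns three intermediates plus a scalar MSE, with the MSE in the last position.

```julia
# Hypothetical toy loss: three intermediates for plotting, plus the scalar MSE.
function loss(x, y, z)
    A = x .* y                      # intermediate (vector)
    B = sum(A) + z                  # intermediate (scalar)
    C = A .- B                      # intermediate (vector of residuals)
    MSE = sum(abs2, C) / length(C)  # the only quantity to differentiate
    return A, B, C, MSE
end

A, B, C, MSE = loss([1.0, 2.0, 3.0], [0.5, -1.0, 2.0], 0.1)
```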
However, I only want the gradient of the MSE (with respect to `x`, `y` and `z`), while I want all of the primal values `A`, `B`, `C` and `MSE`. To get the gradient, it seems like I have to give Zygote the function `(x, y, z) -> loss(x, y, z)[4]`, but then I can't retrieve the primal values `A`, `B` and `C`, even if I use `withgradient`. So my only option seems to be to evaluate the function again, separately from `Zygote.gradient`, which doubles the work done in the forward pass. In principle, what I want should definitely be possible, but how can I achieve it in practice? Am I missing something obvious?
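One possibility worth checking, assuming a recent Zygote release: newer versions of `withgradient` accept a function that returns a `Tuple`, differentiate only its *first* element, and return the whole tuple as `val`. That would require reordering the outputs so the MSE comes first. The reordered `loss_first` below is hypothetical, and older Zygote versions will instead throw an error about the output not being scalar.

```julia
using Zygote

# Hypothetical reordered toy loss: the scalar MSE comes FIRST,
# followed by the intermediates that are only needed for plotting.
function loss_first(x, y, z)
    A = x .* y
    B = sum(A) + z
    C = A .- B
    MSE = sum(abs2, C) / length(C)
    return MSE, A, B, C
end

x = [1.0, 2.0, 3.0]
y = [0.5, -1.0, 2.0]
z = 0.1

# Single forward pass: `res.val` holds the full tuple, and `res.grad`
# holds the gradients of the first element (the MSE) w.r.t. x, y, z.
res = Zygote.withgradient(loss_first, x, y, z)
MSE, A, B, C = res.val
gx, gy, gz = res.grad
```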
To clarify, I'd need to pass something slightly different to the `back` function in the last example, right? Something like `back((1.0, 0., 0.))`? Or would `back((1.0, nothing, nothing))` be more appropriate, or something else entirely?
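For the four-output toy loss described above, a sketch of the `Zygote.pullback` route might look like the following. The loss body is hypothetical; what matters is that the seed passed to `back` mirrors the structure of the output tuple, with `nothing` (Zygote's usual marker for a zero cotangent) in the slots for `A`, `B` and `C` and `1.0` in the MSE slot. A plain `0.0` is an assumption-laden alternative: it may not match the shape when the corresponding output is an array.

```julia
using Zygote

# Hypothetical toy loss returning intermediates A, B, C and the scalar MSE (last).
function loss(x, y, z)
    A = x .* y
    B = sum(A) + z
    C = A .- B
    MSE = sum(abs2, C) / length(C)
    return A, B, C, MSE
end

x = [1.0, 2.0, 3.0]
y = [0.5, -1.0, 2.0]
z = 0.1

# One forward pass: `out` holds all four outputs, `back` is the pullback.
out, back = Zygote.pullback(loss, x, y, z)
A, B, C, MSE = out

# Seed only the MSE slot; `nothing` means "no cotangent" for A, B and C.
gx, gy, gz = back((nothing, nothing, nothing, 1.0))
```

This keeps every primal value from the single forward pass and yields the same gradients as differentiating `loss(x, y, z)[4]` on its own.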