Zygote: how to get intermediate values without evaluating function twice?

Perhaps my issue is best illustrated with a toy example. I have a function `loss(x, y, z)` that returns a scalar value, `MSE`. However, it also computes some intermediate values along the way, which I return as well so that I can plot them. So my loss function returns `A, B, C, MSE`, where `A`, `B` and `C` are intermediate values that I might want to plot as gradient descent progresses.

However, I only want the gradient of `MSE` (with respect to `x`, `y` and `z`), while I want all of the primal values `A`, `B`, `C` and `MSE`. For the gradient, it seems I have to give Zygote the function `(x, y, z) -> loss(x, y, z)[4]`, but then I can't retrieve the primal values `A`, `B` and `C`, even with `withgradient`. So my only option seems to be to evaluate the function again, separately from `Zygote.gradient`, which doubles the forward-pass work. In principle, what I want should definitely be possible, but how can I achieve it in practice? Am I missing something obvious?

If you just wanted the loss, you could use:

```julia
val, grad = withgradient(loss, x, y, z)
```

which just does this:

```julia
val, back = pullback(loss, x, y, z)
grad = back(1.0)  # or Zygote.sensitivity(val) instead of 1.0
```
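A toy check of that equivalence might look like the following. The function `f` and input `x` are made up for illustration, and this assumes a Zygote version in which `withgradient` returns a `(val, grad)` named tuple:

```julia
using Zygote

# Hypothetical scalar loss, for illustration only.
f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

# One call: value and gradient together.
res = Zygote.withgradient(f, x)

# The same thing spelled out: one forward pass via pullback, then the
# backward pass seeded with the output's sensitivity (1.0 for a Float64).
val, back = Zygote.pullback(f, x)
grad = back(Zygote.sensitivity(val))
```

Both routes run the forward pass once; `withgradient` simply bundles the pullback and the seeding.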

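To keep more outputs, extend the `pullback` pattern to a function, say `more`, that returns the loss first and the extra values after it:

```julia
(val, A, B), back = pullback(more, x, y, z)
grad = back(1.0)
```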

To clarify, I'd need to pass something slightly different to `back` in the last example, right? Such as `back((1.0, 0., 0.))`? Or would `back((1.0, nothing, nothing))` be more appropriate, or something else?

Oh right, I didn't try it, sorry. Indeed you want `nothing`, which is what Zygote uses to indicate a gradient that is certainly zero, of any type:

```julia
julia> using Zygote

julia> more(x,y,z) = sum(x*y*z), x+y, (y+z)[1];

julia> x, y, z = eachslice(reshape(1:12, 2, 2, 3), dims=3);

julia> (val, A, B), back = pullback(more, x, y, z)
((2834, [6 10; 8 12], 14), Zygote.var"#52#53"{typeof(∂(more))}(∂(more)))

julia> back((1.0, nothing, nothing))
([254.0 296.0; 254.0 296.0], [60.0 66.0; 140.0 154.0], [57.0 57.0; 77.0 77.0])

julia> back((1.0, 0, 0))
ERROR: MethodError: no method matching +(::Int64, ::Matrix{Float64})

julia> back((1.0, zero(A), zero(B)))
([254.0 296.0; 254.0 296.0], [60.0 66.0; 140.0 154.0], [57.0 57.0; 77.0 77.0])
```
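If you do this often, the seeding can be wrapped up once. The sketch below uses a hypothetical helper, `withgradient_first` (my own name, not part of Zygote's API), and assumes the function returns a tuple whose first element is the scalar loss. It seeds that element with `one(...)` and every other element with `nothing`, reusing the toy `more` and the same input values as the REPL session:

```julia
using Zygote

# Hypothetical helper (not part of Zygote): run `f` once, differentiate only
# its first output, and return every output alongside the gradients.
function withgradient_first(f, args...)
    out, back = Zygote.pullback(f, args...)
    # Seed the loss with one(...) and mark each auxiliary output as
    # certainly zero with `nothing`.
    seed = (one(first(out)), map(_ -> nothing, Base.tail(out))...)
    return out, back(seed)
end

# Same toy function and input values as in the REPL session.
more(x, y, z) = sum(x * y * z), x + y, (y + z)[1]
x, y, z = [1 3; 2 4], [5 7; 6 8], [9 11; 10 12]
out, grads = withgradient_first(more, x, y, z)
```

Seeding with `one(first(out))` rather than a literal `1.0` keeps the seed the same type as the loss, mirroring what `Zygote.sensitivity` does for a scalar.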

Using a closure that assigns to externally scoped variables (or fills some mutable output container**) is also a perfectly acceptable approach:

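```julia
using Zygote

function train_step(x, y, z)  # e.g. one iteration of a training loop
    local A, B, C  # declared out here so the closure's assignments are kept
    val, grad = withgradient(x, y, z) do x, y, z
        A, B, C, MSE = loss(x, y, z)
        return MSE  # only MSE is differentiated
    end
    return val, grad, A, B, C
end
```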
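** If you collect the values in a mutable collection such as an array, wrapping that step in `Zygote.ignore` is highly recommended, since otherwise Zygote may error when it tries to differentiate the mutation.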
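To illustrate that footnote, here is a minimal sketch that logs an intermediate array into an outer `Vector` during the differentiated call. The toy `toy_loss(w, x)` and the `history` container are hypothetical, and the `push!` is wrapped in `Zygote.ignore` so that Zygote does not try to differentiate the mutation:

```julia
using Zygote

# Hypothetical toy loss: returns the scalar loss and an intermediate array `h`.
function toy_loss(w, x)
    h = w .* x
    return sum(abs2, h), h
end

history = Vector{Vector{Float64}}()  # mutable container outside the closure
w, x = [1.0, 2.0], [3.0, 4.0]

val, grad = Zygote.withgradient(w) do w
    l, h = toy_loss(w, x)
    # push! mutates an array; Zygote.ignore runs this block without
    # recording it for the backward pass.
    Zygote.ignore() do
        push!(history, copy(h))
    end
    return l  # only the scalar loss is differentiated
end
```

The forward pass runs once, so `history` gains one entry per call, and the gradient is taken only with respect to the returned loss.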
