Zygote: how to get intermediate values without evaluating function twice?

If you only need the loss value along with the gradient, you can use:

val, grad = withgradient(loss, x, y, z)

which is essentially shorthand for:

val, back = pullback(loss, x, y, z)
grad = back(1.0)  # or Zygote.sensitivity(val) instead of 1.0
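A minimal runnable sketch of that equivalence. It assumes Zygote is installed and uses a hypothetical three-argument `loss` and made-up inputs, none of which come from the original question:

```julia
using Zygote

# Hypothetical scalar loss of three arguments (illustration only)
loss(x, y, z) = sum(abs2, x .* y .- z)

x, y, z = [1.0, 2.0], [3.0, 4.0], [0.5, 0.5]

# One call: value and gradients together
val1, grad1 = withgradient(loss, x, y, z)

# Same thing spelled out: one forward pass, then seed the pullback with 1.0
val2, back = pullback(loss, x, y, z)
grad2 = back(1.0)   # tuple of gradients w.r.t. (x, y, z)
```

Both routes run the forward pass once, so `val1 == val2` and `grad1 == grad2`.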

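To keep intermediate values as well, have the function return them alongside the loss (here `more` returns a tuple `(val, A, B)`) and call `pullback` directly. Because the output is now a tuple, the pullback needs a cotangent with the same structure, using `nothing` for the outputs you don't differentiate: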

(val, A, B), back = pullback(more, x, y, z)
grad = back((1.0, nothing, nothing))  # seed only val; nothing = zero cotangent for A and B
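A self-contained sketch of the extended pattern, assuming Zygote is installed. The body of `more` and the inputs are hypothetical; the point is that the pullback is seeded with a tuple matching the `(val, A, B)` output, with `nothing` for the intermediates:

```julia
using Zygote

# Hypothetical function returning the scalar loss plus two intermediates we want to keep
function more(x, y, z)
    A = x .* y          # first intermediate
    B = A .- z          # second intermediate
    val = sum(abs2, B)  # scalar loss
    return val, A, B
end

x, y, z = [1.0, 2.0], [3.0, 4.0], [0.5, 0.5]

# Single forward pass: keep val, A and B, plus the pullback closure
(val, A, B), back = pullback(more, x, y, z)

# Differentiate only val; A and B get a zero (nothing) cotangent
grad = back((1.0, nothing, nothing))   # tuple of gradients w.r.t. (x, y, z)
```

Since `A` and `B` come out of the same forward pass that the gradient uses, nothing is evaluated twice.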