Flux Get Forward Pass Results when Taking Gradient

I was hoping to keep track of loss values while training with Flux.jl. So far, though, I have not found a way to evaluate the loss and its gradient in the same call, so I have to run them separately:

```julia
ℓ = get_loss(stuff)
∇ = gradient(() -> get_loss(stuff), params)
```

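However, I assume computing the gradient already performs a forward pass, so evaluating the loss separately repeats that work. I was hoping to define my own model that stores intermediate results in pre-allocated arrays, but Zygote does not support array mutation such as setindex! inside the code it differentiates.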

Is there a way to get both the result of the forward pass and the gradient in one go? Or is there a way other than setindex! to store results during a forward pass?
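For the second question, one option that avoids setindex! inside differentiated code is to return the intermediate values alongside the loss and seed only the loss in the reverse pass. The sketch below assumes a hypothetical two-layer model (`loss_and_hidden`, `W1`, `W2` and the data are made up for illustration) and uses only `Zygote.pullback`; passing `nothing` as the sensitivity for the extra output tells Zygote it contributes nothing to the gradient.

```julia
using Zygote

# Hypothetical two-layer model; h is the intermediate result we want to keep.
function loss_and_hidden(W1, W2, x, y)
    h = tanh.(W1 * x)              # hidden activations from the forward pass
    ŷ = W2 * h
    ℓ = sum(abs2, ŷ .- y)
    return ℓ, h                    # return h instead of writing it into a preallocated array
end

W1 = [0.5 -0.5; 0.25 0.75]
W2 = [1.0 2.0]
x = [1.0, 2.0]
y = [0.5]

# One forward pass; the primal output is the tuple (ℓ, h).
(ℓ, h), back = Zygote.pullback((W1, W2) -> loss_and_hidden(W1, W2, x, y), W1, W2)

# Seed only the loss; `nothing` gives h zero sensitivity.
∇W1, ∇W2 = back((one(ℓ), nothing))

# Copying h into a buffer here, outside the differentiated function, is fine.
hidden_log = zeros(2)
hidden_log .= h
```

Mutating `hidden_log` is safe in this sketch because it happens after `pullback` returns, so Zygote never sees the mutation.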


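You should be able to do this with Zygote.pullback, which runs the forward pass once and returns the loss value together with a function that computes the gradient, as in the very last code snippet on this page: Training · Flux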
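A minimal sketch of that pattern, assuming a hypothetical toy loss (`loss`, `W`, `x`, `y` are made up for illustration, not taken from the thread). It shows both the explicit-argument style and the implicit `Params` style that the question's `gradient(() -> ..., params)` call uses:

```julia
using Zygote

# Hypothetical toy loss: mean squared error of a linear model W * x against y.
loss(W, x, y) = sum(abs2, W * x .- y) / length(y)

W = [1.0 2.0; 3.0 4.0]
x = [1.0, 1.0]
y = [0.0, 0.0]

# Explicit style: one forward pass returns the loss value and a pullback closure.
ℓ, back = Zygote.pullback(W -> loss(W, x, y), W)
println("loss = ", ℓ)              # log the loss without re-running the model

# Seed the reverse pass with one(ℓ); back returns one gradient per argument.
∇W, = back(one(ℓ))

# Implicit-parameter style, matching gradient(() -> ..., params) in the question.
ps = Zygote.Params([W])
ℓ2, back2 = Zygote.pullback(() -> loss(W, x, y), ps)
gs = back2(one(ℓ2))                # a Zygote.Grads object
∇W2 = gs[W]                        # index the gradients by the parameter array itself
```

Recent Zygote versions also provide `Zygote.withgradient`, which wraps this pullback-then-seed pattern and returns the value and gradient together.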

Amazing! Thank you