Flux opt_state question

According to the Flux documentation, the training loop

```julia
using Flux

# Initialise the optimiser for this model:
opt_state = Flux.setup(rule, model)

for data in train_set
  # Unpack this element (for supervised training):
  input, label = data

  # Calculate the gradient of the objective
  # with respect to the parameters within the model:
  grads = Flux.gradient(model) do m
      result = m(input)
      loss(result, label)
  end

  # Update the parameters so as to reduce the objective,
  # according to the chosen optimisation rule:
  Flux.update!(opt_state, model, grads[1])
end
```

computes the gradient of the loss function at every step.

Does this mean that if I don't set up `opt_state`, no operations will be tracked, as if `torch.no_grad()` were applied in PyTorch, so that no computation is wasted when gradients aren't needed?


Zygote (the automatic differentiation engine underlying Flux) works differently from PyTorch's autograd: it doesn't record a tape during the forward pass. If you want gradients, you call `Flux.gradient`; if you don't, you just call the function, and it runs as ordinary Julia code.
The optimiser has nothing to do with this: it only determines how the gradient is used to update the model.
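To make the distinction concrete, here is a minimal sketch, assuming a recent Flux version where `Flux.setup` and the `Dense(in => out)` syntax are available. The model, data, loss and learning rate are hypothetical and chosen only for illustration. Calling the model directly is a plain function call; only `Flux.gradient` differentiates, and `opt_state` is used only at `Flux.update!`.

```julia
using Flux

# Hypothetical toy model, data and loss, used only for illustration.
model = Dense(3 => 1)
x = rand(Float32, 3, 8)
y = rand(Float32, 1, 8)
loss(m, x, y) = Flux.mse(m(x), y)

# Plain forward pass: ordinary Julia code, nothing is recorded for AD.
ŷ = model(x)
l0 = loss(model, x, y)

# Differentiation happens only inside Flux.gradient.
grads = Flux.gradient(m -> loss(m, x, y), model)

# The optimiser state is needed only to apply that gradient to the model.
opt_state = Flux.setup(Descent(0.1), model)
Flux.update!(opt_state, model, grads[1])
l1 = loss(model, x, y)

println("loss before: ", l0, ", after one step: ", l1)
```

Note that you can skip `Flux.setup` entirely for inference: without a call to `Flux.gradient`, no gradient work is done at all.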


Thanks for the clarification!