How to retrieve gradient value in custom Flux training loop?

Hello,

I have a few questions about custom training loops in Flux. According to the documentation, I should write something like this:

using Flux
using Flux.Optimise: update!

function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging
      return training_loss
    end
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end
  1. If the ps argument is Flux.params(my_model), like in the Flux.train! method, is the ps = Params(ps) line redundant?
  2. gs seems to be an instance of Zygote.Grads; how can I retrieve the gradient values in the loop? I tried gs[ps] and gs[my_model] without any success.
  1. Yes, it shouldn’t be needed.
  2. The Grads type contains the fields grads and params, so you can access the raw gradients as gs.grads (an IdDict keyed by the parameter arrays).
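For example, with a throwaway Dense model and random data (a minimal sketch; indexing gs with a parameter array is usually more convenient than reading gs.grads directly):

using Flux

model = Dense(2, 1)
ps = Flux.params(model)
x, y = rand(Float32, 2, 4), rand(Float32, 1, 4)

gs = gradient(() -> Flux.Losses.mse(model(x), y), ps)

gs.grads       # IdDict mapping each parameter array to its gradient
W = first(ps)  # e.g. the weight matrix of the Dense layer
gs[W]          # gradient of the loss with respect to W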
  1. If the ps argument is Flux.params(my_model), like in the Flux.train! method, is the ps = Params(ps) line redundant?

Yes.
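You can verify this directly (a minimal sketch; the Dense layer is just a placeholder model):

using Flux

model = Dense(2, 1)
ps = Flux.params(model)
ps isa Flux.Params  # true: Flux.params already returns a Zygote.Params,
                    # so wrapping it again with Params(ps) is redundant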

  2. gs seems to be an instance of Zygote.Grads; how can I retrieve the gradient values in the loop? I tried gs[ps] and gs[my_model] without any success.
for p in ps
  println(gs[p])  # gradient of the loss with respect to parameter p
end
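If you want a single scalar to log each step, you can also reduce over the parameters. A sketch reusing gs and ps from the loop above (it assumes every parameter took part in the loss, so no gs[p] is nothing):

grad_norm = sqrt(sum(p -> sum(abs2, gs[p]), ps))  # global gradient norm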

Thank you for your help! This is exactly what I was looking for.