How to retrieve gradient value in custom Flux training loop?

Hello,

I have a few questions about custom training loops in Flux. According to the documentation, I should write something like this:

function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging
      return training_loss
    end
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end
  1. If the ps argument is Flux.params(my_model), as in the Flux.train! method, is the ps = Params(ps) line redundant?
  2. gs seems to be an instance of Zygote.Grads; how can I retrieve the gradient values in the loop? I tried gs[ps] and gs[my_model] without any success.
  1. Yes, it shouldn’t be needed.
  2. The Grads type contains the fields grads and params, so you can access the underlying dictionary as gs.grads.
  1. If the ps argument is Flux.params(my_model), as in the Flux.train! method, is the ps = Params(ps) line redundant?

yes

  1. gs seems to be an instance of Zygote.Grads; how can I retrieve the gradient values in the loop? I tried gs[ps] and gs[my_model] without any success.
for p in ps
  println(gs[p])  # index Grads with a parameter array to get its gradient
end
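
Putting the answers together, here is a minimal sketch of the whole loop with gradient logging. The model, data, loss, and optimiser below are made-up placeholders just for illustration; this uses the implicit-parameter (Flux.params / Zygote.Params) style that the rest of this thread assumes:

using Flux
using Flux.Optimise: update!

# Hypothetical toy setup, not from the original post
model = Dense(2, 1)
loss(x, y) = Flux.Losses.mse(model(x), y)
data = [(rand(Float32, 2, 8), rand(Float32, 1, 8))]
opt = Descent(0.1)

ps = Flux.params(model)   # already a Zygote.Params, so no extra Params(ps) needed
for d in data
  gs = gradient(ps) do
    loss(d...)
  end
  for p in ps
    # gs[p] is the gradient array matching parameter p; log it however you like
    println(sum(abs, gs[p]))
  end
  update!(opt, ps, gs)
end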

Thank you for your help! This is exactly what I was looking for.