I have a few questions about custom training loops in Flux. According to the documentation, I should write something like this:
```julia
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging
      return training_loss
    end
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end
```
- If I pass `Flux.params(my_model)` as `ps`, like the `Flux.train!` method does, is the `ps = Params(ps)` line redundant?
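For example, something like this (a minimal sketch; `my_model = Dense(2, 1)` is just a toy model I made up for illustration):

```julia
using Flux
using Zygote: Params

my_model = Dense(2, 1)       # toy model, just for illustration
ps = Flux.params(my_model)   # this already seems to be a Zygote.Params
@show ps isa Params          # so is wrapping it again with Params(ps) a no-op?
```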
- `gs` seems to be an instance of `Zygote.Grads`. How can I retrieve the gradient values inside the loop? I tried `gs[my_model]` without any success.
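From reading the Zygote docs I suspect that a `Grads` is indexed by the individual parameter arrays rather than by the model object, but I am not sure this is the intended usage (again a minimal sketch with a made-up toy model and a plain squared-error loss):

```julia
using Flux
using Zygote

my_model = Dense(2, 1)                    # toy model, just for illustration
ps = Flux.params(my_model)
x, y = rand(Float32, 2), rand(Float32, 1)

gs = gradient(ps) do
    sum(abs2, my_model(x) .- y)           # simple squared-error loss
end

# gs[my_model] errors, but indexing by the parameter arrays appears to work:
for p in ps
    @show size(gs[p])                     # each gradient has the same shape as its parameter
end
```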