I have a few questions about custom training loops in Flux. According to the documentation, I should write something like this:
```julia
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging
      return training_loss
    end
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end
```
- If I pass `Flux.params(my_model)` as `ps`, like the `Flux.train!` method does, is the `ps = Params(ps)` line redundant?
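For example, something like this (a minimal sketch; `my_model = Dense(2, 1)` is just a toy model I made up for illustration):

```julia
using Flux
using Zygote: Params

my_model = Dense(2, 1)       # toy model, just for illustration
ps = Flux.params(my_model)   # this already seems to be a Zygote.Params
@show ps isa Params          # so is wrapping it again with Params(ps) a no-op?
```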
- `gs` seems to be an instance of `Zygote.Grads`. How can I retrieve the gradient values inside the loop? I tried `gs[my_model]` without any success.
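From reading the Zygote docs I suspect that a `Grads` is indexed by the individual parameter arrays rather than by the model object, but I am not sure this is the intended usage (again a minimal sketch with a made-up toy model and a plain squared-error loss):

```julia
using Flux
using Zygote

my_model = Dense(2, 1)                    # toy model, just for illustration
ps = Flux.params(my_model)
x, y = rand(Float32, 2), rand(Float32, 1)

gs = gradient(ps) do
    sum(abs2, my_model(x) .- y)           # simple squared-error loss
end

# gs[my_model] errors, but indexing by the parameter arrays appears to work:
for p in ps
    @show size(gs[p])                     # each gradient has the same shape as its parameter
end
```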