I just saw an even better answer. DiffEqFlux.jl uses a similar pattern in its custom training loop, but moves the logging call outside the gradient closure by declaring the loss variable local. This works and, if you ask me, is a bit easier to understand:
using Flux
using Flux: Params

function my_custom_train!(loss, ps, data, opt)
  # Declare training_loss local so we can use it outside the gradient calculation.
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
    end
    # Insert whatever code you want here that needs training_loss, e.g. logging.
    evalcb(training_loss)  # eventually want to pass out training_loss...
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge.
    Flux.update!(opt, ps, gs)
    # Here you might like to check validation-set accuracy, and break out to do early stopping.
  end
end
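
For completeness, here is a minimal sketch of how you might drive this loop. The toy Dense model, random data, and the evalcb definition are just illustrative assumptions on my part (same implicit-Params style as the function above), not anything taken from DiffEqFlux.jl:

using Flux

# Toy model and data, purely for illustration.
model = Dense(3, 1)
data  = [(rand(Float32, 3, 16), rand(Float32, 1, 16)) for _ in 1:10]

# Loss with the same call signature as loss(d...) inside the loop.
loss(x, y) = sum(abs2, model(x) .- y) / length(y)

# Hypothetical stand-in for the evalcb logging callback used above.
evalcb(training_loss) = @info "training loss" training_loss

opt = Descent(0.01)
my_custom_train!(loss, Flux.params(model), data, opt)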