In the Flux documentation there is an example of a custom training loop, and it indicates where you would place logging code (see below). I want to log the loss by initializing LossLog = Float64[] outside the function and then calling push!(LossLog, training_loss) inside it. I've tried placing that push! call just above the return training_loss statement and also just before update!, but both attempts fail with an error saying Mutating arrays not supported, which appears to come from Zygote.
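Concretely, here is a minimal sketch of what I am attempting; the LossLog name is mine, and the toy model, loss, and data below are just placeholders so the snippet stands on its own:

using Flux

LossLog = Float64[]                          # loss history, lives outside the training function

model = Dense(3, 1)                          # placeholder model
loss(x, y) = Flux.mse(model(x), y)           # placeholder loss
ps = Flux.params(model)
d = (rand(Float32, 3, 5), rand(Float32, 1, 5))   # one placeholder batch

gs = gradient(ps) do
    training_loss = loss(d...)
    push!(LossLog, training_loss)            # <- this is the line that triggers the error
    return training_loss
end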
# Unchanged code from documentation. I got errors when I tried
# adding a `push!` statement to log the loss as described above.
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs training_loss, e.g. logging
      return training_loss
    end
    # Insert whatever code you want here that needs the gradient
    # E.g. logging with TensorBoardLogger.jl as a histogram, so you can see if it is becoming huge
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end
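For context, here is roughly how I am calling it; the model, loss, data, and optimiser below are just stand-ins:

using Flux
using Flux: Params                    # Params is used inside my_custom_train!
using Flux.Optimise: update!          # update! is used inside my_custom_train!

model = Chain(Dense(3, 4, relu), Dense(4, 1))                         # placeholder model
loss(x, y) = Flux.mse(model(x), y)                                    # placeholder loss
data = [(rand(Float32, 3, 10), rand(Float32, 1, 10)) for _ in 1:5]    # placeholder batches
opt = Descent(0.1)                                                    # placeholder optimiser

my_custom_train!(loss, Flux.params(model), data, opt)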
What is the problem, and what do I need to do to make this work?
Also, I want to understand what the code is doing. I understand (I hope!) the do-block syntax as described in the docs, but I'm not sure how the assignment gs = ... factors in. Naively, I'd guess that the code applies the function loss(d...) to each element of the collection gradient(ps), and that the output of all this is then saved to the variable gs. However, gradient(ps) is, I think, [∇W1, ∇b1, ∇W2, ∇b2, …]. And there has to be some way to get the loss to update!; there's no explicit loss input, so it must be included in gs. But if that's true, the do block is building a tuple of the gradients and the loss in a manner I'm not familiar with (which isn't saying much). Point being, I'm confused and could use some help.
Edit: it's of course the gradient of the loss function that needs to be passed to update!. I also understand do blocks better now, and I see that Flux's example code is assigning to gs the output of gradient(() -> loss(d...), ps); the body of the do block becomes an anonymous zero-argument function that is passed as the first argument to gradient.
. And so, for a loss L, I think we have
gs = [∇W1L, ∇b1L, ∇W2L, ∇b2L, …]
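In other words (if I have this right), the do block is just syntax for passing an anonymous function as the first argument, and gs can be indexed by each parameter to pull out that parameter's gradient. A minimal standalone sketch, using plain Zygote and made-up parameter names:

using Zygote

W = rand(2, 3); b = rand(2)
ps = Params([W, b])
x = rand(3); y = rand(2)
loss() = sum(abs2, W * x .+ b .- y)

# The do-block form ...
gs = gradient(ps) do
    loss()
end

# ... is equivalent to passing a zero-argument closure explicitly:
gs = gradient(() -> loss(), ps)

# gs behaves like a dictionary keyed by the parameters themselves,
# which is how update!(opt, ps, gs) finds the gradient for each parameter:
gs[W]    # gradient of the loss with respect to W
gs[b]    # gradient of the loss with respect to b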