Flux: Custom Training + Logging

I just saw an even better answer. DiffEqFlux.jl uses a similar pattern in its custom training loop, but moves the logging call outside the gradient calculation by declaring the loss local. This works and is a bit easier to understand, if you ask me:

using Flux
using Flux: Params

function my_custom_train!(loss, ps, data, opt)
  # Declare training_loss local so it is available outside the gradient calculation.
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
    end
    # Insert whatever code you want here that needs training_loss, e.g. logging.
    evalcb(training_loss) # eventually want to pass out training_loss...

    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it as a histogram with TensorBoardLogger.jl so you can see if it is blowing up.
    Flux.update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping.
  end
end
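
For reference, here is a minimal sketch of how you might call this loop. The model, synthetic data, and evalcb below are made-up stand-ins for illustration, not part of the DiffEqFlux.jl source:

# Hypothetical call-site sketch (assumes `using Flux` and the definition of
# `my_custom_train!` above): a small regression model, a mean-squared-error loss,
# random data batches, and a callback that just prints the training loss.
model = Chain(Dense(10, 5, relu), Dense(5, 1))
loss(x, y) = sum(abs2, model(x) .- y) / size(y, 2)
data = [(rand(Float32, 10, 32), rand(Float32, 1, 32)) for _ in 1:100]
opt = Descent(0.01)

evalcb(training_loss) = @info "training loss" training_loss

my_custom_train!(loss, Flux.params(model), data, opt)

Because training_loss is declared local before the do block, the value assigned inside the closure is still visible when evalcb is called after gradient returns.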