Flux: Custom Training + Logging

In the Flux documentation there is an example of a custom training routine, with comments indicating where to place logging code (see below). I want to log the loss by initializing LossLog = Float64[] outside the function and then calling push!(LossLog, training_loss). I've tried placing this snippet just above the return training_loss statement and also just before update!, but both placements give an error saying "Mutating arrays is not supported", which appears to come from Zygote.

# Unchanged code from documentation. I got errors when I tried 
# adding a `push!` statement to log the loss as described above.
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert what ever code you want here that needs Training loss, e.g. logging
      return training_loss
    end
    # insert what ever code you want here that needs gradient
    # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end
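
For concreteness, my attempt looked roughly like this, with the push! just above the return:

LossLog = Float64[]

function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      push!(LossLog, training_loss) # this is the line that triggers the error
      return training_loss
    end
    update!(opt, ps, gs)
  end
end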

What is the problem, and what do I need to do to make this work?

Also, I want to understand what the code is doing. I understand (I hope!) the do block syntax as described in the docs, but I'm not sure how the assignment gs = ... factors in. Naively, I'd guess that the code applies the function loss(d...) to each element of the collection gradient(ps), and that the output of all this is then saved to a variable gs.

However, gradient(ps) is, I think, [∇W1, ∇b1, ∇W2, ∇b2, …]. And there has to be a way to get the loss to update!; there’s no explicit input of the loss, so it must be included in gs. But if that’s true, the do block is making a tuple of the gradients and the loss in a manner I’m not familiar with (which isn’t saying much). Point being, I’m confused and could use some help.

Edit: it's of course the gradient of the loss function that needs to be passed to update!. I also better understand do blocks now, and I see that Flux's example code is assigning to gs the output of gradient(() -> loss(d...), ps), with the do block supplying the anonymous function. And so, for a loss L, I think we have

gs = [∇W1L, ∇b1L, ∇W2L, ∇b2L, …]
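
More precisely, the do block is just syntax for passing a zero-argument anonymous function as the first argument to gradient, and the returned gs is a Zygote.Grads that maps each parameter array to its gradient, something like:

# Equivalent to the do-block form, with the closure written explicitly:
gs = gradient(() -> loss(d...), ps)

# gs can be indexed by a parameter array to get that parameter's gradient:
gs[first(ps)]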

For clarity, I placed a call to evalcb() after training_loss = loss(d...). I simplified the callback to eliminate the push! command, and I pre-allocate an array to store values:

N = 100000
idx = 1
LossLog = Array{Float64, 1}(undef, N)
function evalcb()
    global idx
    global LossLog
    @show typeof(LossLog), size(LossLog) # verifies we are accessing global
    if idx < N
        # FAILS when the line below is included:
        LossLog[idx] = 0.1 # dummy value

        if true
            idx += 1
            println("Next idx = "*string(idx)*"; LossLog[1]="*string(LossLog[1])) # prints once before failure
        end
    end
end

If I comment out the array assignment, it runs: it's able to increment idx, as I can see from the print statement. But with the array assignment in, running my_custom_train! gives the error:

(typeof(LossLog), size(LossLog)) = (Array{Float64,1}, (100000,))
Next idx = 2; LossLog[1]=0.1

Mutating arrays is not supported

Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.var"#1048#1049")(::Nothing) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\array.jl:61
 [3] (::Zygote.var"#2775#back#1050"{Zygote.var"#1048#1049"})(::Nothing) at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [4] evalcb at .\In[127]:39 [inlined]
 [5] (::typeof(∂(evalcb)))(::Nothing) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [6] #113 at .\In[127]:60 [inlined]
 [7] (::typeof(∂(λ)))(::Float64) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [8] (::Zygote.var"#49#50"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:179
 [9] gradient(::Function, ::Params) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:55
 [10] my_custom_train!(::Function, ::Params, ::DataLoader, ::ADAM) at .\In[127]:57
 [11] top-level scope at .\In[127]:73

The full code to reproduce the error is below. Thanks in advance for any help or insights you can provide!

using Distributions
using Plots
using Flux
using Flux: mse
using Flux.Data: DataLoader
using Zygote
using Zygote: Params

##### DATA #####################
num_samples = 50
x_noise_std = 0.01
y_noise_std = 0.25
function generate_linear_data()
    x = reshape(range(-1, stop=1, length=num_samples), num_samples, 1)
    x_noise = rand(Normal(0,x_noise_std), num_samples)
    y_noise = rand(Normal(0,y_noise_std), num_samples)
    
    y = 3 .* x .+ y_noise
    
    x = transpose(x)
    y = transpose(y)
    
    return x, y
end
X, Y = generate_linear_data() # Training data of shape (1,num_samples)

train_loader = DataLoader(X, Y, batchsize=10, shuffle=true) 

##### CALLBACK #################
N = 100000
idx = 1
LossLog = Array{Float64, 1}(undef, N)
function evalcb()
    global idx
    global LossLog
    @show typeof(LossLog), size(LossLog) # verifies we are accessing global
    if idx < N
        # FAILS when the line below is included:
        LossLog[idx] = 0.1 # dummy value

        if true
            idx += 1
            println("Next idx = "*string(idx)*"; LossLog[1]="*string(LossLog[1])) # prints once before failure
        end
    end
end

##### MODEL & TRAINING #####################
m = Chain(Dense(size(X, 1), 10, tanh), Dense(10, 10, tanh), Dense(10, size(Y,1), tanh))
opt = ADAM()
loss(x, y) = mse(m(x), y)

# From https://fluxml.ai/Flux.jl/v0.10/training/training/#Custom-Training-loops-1
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert what ever code you want here that needs Training loss, e.g. logging
      evalcb() # eventually want to pass out training_loss...
      return training_loss
    end

    # insert what ever code you want here that needs gradient
    # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge
    Flux.update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end

for epoch in 1:100
    # my_custom_train! already iterates over the whole DataLoader,
    # so one call per epoch is enough.
    my_custom_train!(loss, Flux.params(m), train_loader, opt)
end

Zygote has two tools for stopping gradients; I think ignore is the one you want.

dropgrad can be used in an expression to remove the enclosed variable from the backward pass.

ignore can be used to remove a whole block from the backward pass.

dropgrad does not seem to apply to assignment, so I think you can use ignore like this:

Zygote.ignore() do
    evalcb(training_loss)
end
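
In context, inside the do block, that would look something like:

gs = gradient(ps) do
  training_loss = loss(d...)
  Zygote.ignore() do
    evalcb(training_loss) # runs as normal, but Zygote does not try to differentiate it
  end
  return training_loss
end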

I just saw an even better answer. DiffEqFlux.jl uses a similar pattern in its custom training loop, but moves the logging call outside the gradient call by declaring the loss variable to be local. This works and is a bit easier to understand, if you ask me:

function my_custom_train!(loss, ps, data, opt)
  # declare training loss local so we can use it outside gradient calculation
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
    end
    # Insert what ever code you want here that needs Training loss, e.g. logging
    evalcb(training_loss) # eventually want to pass out training_loss...

    # insert what ever code you want here that needs gradient
    # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge
    Flux.update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping
  end
end

Thanks. That’s actually what I ended up doing late yesterday – just pulling the callback out of the do block.

Having read much more on do blocks (so I mostly understand them), I now see that Flux's example code is assigning to gs the output of gradient(() -> loss(d...), ps), with the do block supplying the anonymous function. By placing the logging in the do block (as the example comment had indicated), the callback also ends up inside the function that gradient differentiates. That's not really what we want, and that's why it would require Zygote's dropgrad or ignore, as you indicated in your original reply.

Placing the logging after the do block is certainly easiest, and it is probably just as efficient as logging inside the do block (if not more so).
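
For reference, what I ended up with looks roughly like this (log_loss here is just a stand-in name for my callback, and it assumes the same model, loss, data loader, and optimiser as above):

LossLog = Float64[]
log_loss(l) = push!(LossLog, l)

function my_custom_train!(loss, ps, data, opt)
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
    end
    log_loss(training_loss) # outside the do block, so it is not differentiated
    Flux.update!(opt, ps, gs)
  end
end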

Would be good to update the docs for this

How about this?

That’s great! Thanks for doing that.