Why doesn't the loss calculated by Flux `withgradient` match what I have calculated?

Just posted this to StackOverflow, but crossposting here to reach the Julia community.

TL;DR – while building a simple CNN I noticed a mismatch between the loss reported during training and the loss I calculated afterwards from the "trained" model (the trained model looked like garbage despite the reported loss suggesting it should be quite good). The crux of the problem I identified is that when the model contains a BatchNorm layer, the loss computed from the model appears to change after calling withgradient.

For example:

using Flux
x = rand(100,100,1,1) # a 100x100 greyscale image: 1 channel, batch size 1
y = @. 5*x + 3 #output image, some relationship to the input values (doesn't matter for this)
m = Chain(BatchNorm(1),Conv((1,1),1=>1)) #very simple model (doesn't really do anything but illustrates the problem)
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m) # loss returned by withgradient alongside the gradient
l_final = Flux.mse(m(x),y) #loss calculated again using the model (no parameters have been updated)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")

prints a different value for each of the three loss calculations (sometimes very different ones), whereas running the same code with a model that has no BatchNorm layer:

using Flux
x = rand(100,100,1,1) # a 100x100 greyscale image: 1 channel, batch size 1
y = @. 5*x + 3 #output image
m = Chain(Conv((1,1),1=>1))
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m)
l_final = Flux.mse(m(x),y)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")

prints the same value for all three loss calculations (which is what I expected).

Why does adding BatchNorm do this, and is this intended? Am I using it wrong?

I presume you're used to PyTorch, where you have to manually specify whether a layer is in train or test/eval mode? Flux does this automatically by default: taking gradients of the model lets BatchNorm detect that it is being trained, so the forward pass inside withgradient uses the batch statistics and updates the running statistics, whereas a plain m(x) outside a gradient context uses the (now changed) running statistics. That is why your three loss values differ. If you'd like more manual control, trainmode! and testmode! are available per the docs.
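For example, here is a rough sketch of your first snippet with the mode pinned explicitly (the only change is the Flux.testmode! call; everything else is your code):

using Flux
x = rand(100,100,1,1)
y = @. 5*x + 3
m = Chain(BatchNorm(1),Conv((1,1),1=>1))
Flux.testmode!(m) # BatchNorm now always uses its stored running statistics and never updates them
l_init = Flux.mse(m(x),y)
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m) # the forward pass inside the gradient no longer touches the running stats
l_final = Flux.mse(m(x),y)
println(l_init ≈ l_grad ≈ l_final) # true: all three losses agree

Conversely, Flux.trainmode!(m) forces the training behaviour (use batch statistics and update the running statistics) even outside a gradient call. In an actual training loop you would typically leave the automatic mode on while training and call testmode! only when evaluating the trained model.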


Thanks, yes, applying those explicitly makes the values the same!