Why doesn't the loss calculated by Flux `withgradient` match what I have calculated?

I just posted this to StackOverflow, but I'm crossposting here to reach the Julia community.

TL;DR – while building a simple CNN I noticed a discrepancy between the loss reported during training and the loss I calculated myself from the "trained" model (the trained model looked like garbage despite the reported loss saying it should be quite good). The crux of the problem I identified is that when the model contains a BatchNorm layer, the loss value appears to change after calling withgradient.

For example:

using Flux
x = rand(Float32, 100, 100, 1, 1) # a 100x100 greyscale image: 1 channel, batch size 1 (Float32 matches Flux's default parameter type)
y = @. 5f0*x + 3f0 # target image, some fixed relationship to the input (the exact values don't matter here)
m = Chain(BatchNorm(1), Conv((1,1), 1=>1)) # very simple model (doesn't do anything useful, but illustrates the problem)
l_init = Flux.mse(m(x), y) # initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m) # loss returned by withgradient
l_final = Flux.mse(m(x), y) # loss calculated again from the model (no parameters have been updated)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")

produces a different value for each of the three loss calculations (sometimes very different), whereas running the same code with a model that has no BatchNorm layer:

using Flux
x = rand(Float32, 100, 100, 1, 1) # same greyscale image setup as above
y = @. 5f0*x + 3f0 # target image
m = Chain(Conv((1,1), 1=>1)) # same model, but without the BatchNorm layer
l_init = Flux.mse(m(x), y) # initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m) # loss returned by withgradient
l_final = Flux.mse(m(x), y) # loss calculated again from the model
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")

produces the same value for all three loss calculations (which is what I expected).

Why does adding BatchNorm do this, and is this intended? Am I using it wrong?

I presume you’re used to PyTorch and having to manually specify whether a layer is in train or test/eval mode? Flux does this automatically by default: taking gradients through a model lets BatchNorm detect that it is being trained, so inside the withgradient call it normalises with the batch statistics and updates its running statistics, while outside the call it normalises with the (now changed) running statistics. That is why all three of your loss values differ. If you’d like more manual control, trainmode! and testmode! are available, per the docs.
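
For example, forcing the model into test mode before any of the calls should make all three values agree. Here is a minimal sketch based on your snippet (testmode! freezes BatchNorm so it always uses its stored statistics):

using Flux
x = rand(Float32, 100, 100, 1, 1)
y = @. 5f0*x + 3f0
m = Chain(BatchNorm(1), Conv((1,1), 1=>1))
Flux.testmode!(m) # freeze BatchNorm: use stored running stats everywhere, even inside withgradient
l_init = Flux.mse(m(x), y)
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m)
l_final = Flux.mse(m(x), y)
# l_init, l_grad and l_final now match, since BatchNorm no longer switches behaviour or mutates its stats

Conversely, Flux.trainmode!(m) forces training behaviour everywhere. For actual training you would normally leave the model in the default automatic mode and call testmode! only when evaluating.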

Thanks, yes, applying those explicitly makes the values the same!