Just posted this to StackOverflow, but cross-posting here for visibility with the Julia community.
TL;DR – while training a simple CNN I noticed weird behavior: the loss reported during training didn't match the loss I actually calculated from the "trained" model (the trained model looked like garbage despite the reported loss saying it should be quite good). The crux of the problem, as far as I can tell, is that with a `BatchNorm` layer in the chain, the loss value the model produces changes after calling `withgradient`.
For example:
```julia
using Flux
x = rand(100, 100, 1, 1)  # say a 100x100 greyscale image: 1 channel, batch size 1
y = @. 5 * x + 3          # target image, some relationship to the input (doesn't matter here)
m = Chain(BatchNorm(1), Conv((1, 1), 1 => 1))  # very simple model; doesn't do much, but illustrates the problem
l_init = Flux.mse(m(x), y)  # initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m)  # loss calculated by withgradient
l_final = Flux.mse(m(x), y)  # loss calculated again from the model (no parameters have been updated)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
```
produces a different value for each of the three loss calculations (sometimes very different), whereas running the same model without the `BatchNorm` layer:
```julia
using Flux
x = rand(100, 100, 1, 1)  # same greyscale image as above
y = @. 5 * x + 3          # same target image
m = Chain(Conv((1, 1), 1 => 1))  # same model, minus the BatchNorm
l_init = Flux.mse(m(x), y)  # initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m)
l_final = Flux.mse(m(x), y)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
```
prints the same loss value all three times (which is what I expected).
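One thing that makes me suspect internal state: if I'm reading Flux's `BatchNorm` internals correctly, the layer stores its running statistics in `μ` and `σ²` fields, so the sketch below (the field access is my assumption, not documented API) should show whether the gradient call touches them:

```julia
using Flux
x = rand(100, 100, 1, 1)
y = @. 5 * x + 3
m = Chain(BatchNorm(1), Conv((1, 1), 1 => 1))
println("running mean before withgradient: ", m[1].μ)  # internal running mean (assumed field name)
Flux.withgradient(m -> Flux.mse(m(x), y), m)           # no parameter update, just the gradient call
println("running mean after withgradient:  ", m[1].μ)  # has the forward pass mutated it?
```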
Why does adding `BatchNorm` do this, and is this intended? Am I using it wrong?
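For what it's worth, I'd expect that forcing test mode before any of the calls (via `Flux.testmode!`, which as I understand it makes the norm layer use its stored statistics without updating them) would make the three numbers agree, though presumably that also stops `BatchNorm` from learning anything. A minimal sketch, assuming `testmode!` does what I think it does:

```julia
using Flux
x = rand(100, 100, 1, 1)
y = @. 5 * x + 3
m = Chain(BatchNorm(1), Conv((1, 1), 1 => 1))
Flux.testmode!(m)  # freeze the norm layer: use (and don't update) its running statistics
l_init = Flux.mse(m(x), y)
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m)
l_final = Flux.mse(m(x), y)
println("initial: $l_init | withgradient: $l_grad | final: $l_final")
```

But if that's the fix, how is one supposed to train a model that contains `BatchNorm`?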