A simple example with a batch of 20 feature vectors, each with 4 elements (a 20 × 4 input for PyTorch, a 4 × 20 input for Flux, since Flux puts the batch dimension last).

This is PyTorch's BatchNorm1d, which has essentially the same default setup as Flux's BatchNorm, i.e. affine=true, track_stats=true, momentum=0.1, eps=1e-5:

```
>>> import torch
>>> import torch.nn as nn
>>> X = torch.rand(20,4)
>>> bn = nn.BatchNorm1d(4)
>>> Y = bn(X)
>>> torch.mean(Y,dim=0), torch.std(Y,dim=0)
(tensor([ 1.8254e-08, 2.3842e-08, -3.8743e-08, 1.3113e-07],
grad_fn=<MeanBackward1>), tensor([1.0259, 1.0259, 1.0259, 1.0259], grad_fn=<StdBackward0>))
```
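
As an aside, the standard deviations come out as 1.0259 rather than 1.0 because torch.std (like Julia's Statistics.std) applies Bessel's correction, dividing by n-1, while BatchNorm normalises with the biased batch variance. For a batch of 20 the expected ratio is:

```
julia> sqrt(20/19)   # ≈ 1.0259, the corrected/uncorrected std ratio for n = 20
```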

This is Flux:

```
julia> using Flux, Statistics
julia> X = rand(4,20);
julia> bn = BatchNorm(4);
julia> Y = bn(X);
julia> mean(Y, dims=2), std(Y, dims=2)
([0.48133999685014006; 0.44180285698448074; 0.3883897120545002; 0.5147873246939905;;], [0.29354726908422013; 0.28174793700550516; 0.28389460891762774; 0.3641842220050316;;])
```
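
These means (≈ 0.5) and standard deviations (≈ 0.29) are just the statistics of the raw uniform samples in X. If the layer is normalising with its initial running statistics (mean 0, variance 1) rather than with the batch statistics, then Y should simply be X scaled by 1/sqrt(1 + eps), i.e. essentially the identity, which can be checked:

```
julia> isapprox(Y, X, rtol=1e-4)   # should return true if Y is just X rescaled by 1/sqrt(1 + eps)
```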

The output does not appear to be normalised. However, setting track_stats=false produces normalised output:

```
julia> bn = BatchNorm(4, track_stats=false);
julia> Y = bn(X);
julia> mean(Y, dims=2), std(Y, dims=2)
([-3.774758283725532e-16; -1.4988010832439614e-16; -1.7208456881689927e-16; 2.2204460492503132e-17;;], [1.0259156929542212; 1.0259103353840449; 1.0259113600126541; 1.0259376410505154;;])
```
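
Here the means are at machine-epsilon scale and the standard deviations are again ≈ sqrt(20/19) ≈ 1.0259, i.e. the batch really is standardised, up to Bessel's correction and eps. Using the uncorrected estimator should confirm this:

```
julia> std(Y, dims=2, corrected=false)   # should be ≈ 1.0 in every row
```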

However, PyTorch's BatchNorm1d with track_running_stats=False returns the same normalised output as the default case. This is to be expected here: a freshly constructed nn.Module is in training mode, and in training mode BatchNorm1d always normalises with the current batch statistics, so track_running_stats only affects evaluation.

```
>>> bn = nn.BatchNorm1d(4, track_running_stats=False)
>>> Y = bn(X)
>>> torch.mean(Y,dim=0), torch.std(Y,dim=0)
(tensor([ 1.8254e-08, 2.3842e-08, -3.8743e-08, 1.3113e-07],
grad_fn=<MeanBackward1>), tensor([1.0259, 1.0259, 1.0259, 1.0259], grad_fn=<StdBackward0>))
```
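
One more data point: Flux layers can be forced into training mode with Flux.trainmode!. If I have understood the docs correctly, the default (track_stats=true) layer should then normalise with the current batch statistics, so this (untested) sketch ought to reproduce the track_stats=false numbers above:

```
julia> bn = BatchNorm(4);               # default: track_stats = true
julia> Flux.trainmode!(bn);             # force training-mode behaviour
julia> Y = bn(X);
julia> mean(Y, dims=2), std(Y, dims=2)  # expecting means ≈ 0 and stds ≈ 1.0259
```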

What is the intuition behind Flux's BatchNorm behaving differently? It does not seem to match the methodology in the original batch normalisation paper (Ioffe & Szegedy, 2015).