BatchNorm gradient is inconsistent with PyTorch in training mode

I found that Flux’s BatchNorm output in both training and test mode, as well as its gradient in test mode, is consistent with PyTorch, but its gradient in training mode is not. This is strange, since all of the parameters appear to be identical.

Example:

# Load Flux (Julia) and, via PyCall, the PyTorch reference implementation.
using Flux
using PyCall

const torch = pyimport("torch")
const nn = pyimport("torch.nn")

# One BatchNorm layer per framework, both with 64 channels, so their
# parameters, outputs, and gradients can be compared one-to-one.
bn1_py = nn.BatchNorm2d(64)
bn1 = BatchNorm(64)

# Following code shows consistency between initial parameters

# Numerical-stability constant (ϵ in Flux, eps in PyTorch).
bn1.ϵ, bn1_py.eps
# (1.0f-5, 1.0e-5)

# Momentum used for the running-statistics update.
bn1.momentum, bn1_py.momentum 
# (0.1f0, 0.1)

# Learnable scale (γ in Flux, weight in PyTorch) — both initialized to ones.
bn1.γ, bn1_py.weight
#=
(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], PyObject Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True))
=#

# Learnable shift (β in Flux, bias in PyTorch) — both initialized to zeros.
bn1.β, bn1_py.bias
#=
(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], PyObject Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True))
=#

# Running mean (μ / running_mean) — both start at zero.
bn1.μ, bn1_py.running_mean
#=
(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], PyObject tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
=#

# Running variance (σ² / running_var) — both start at one.
bn1.σ², bn1_py.running_var
#=
(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], PyObject tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]))
=#

# Same random input for both frameworks. PyTorch uses (N, C, H, W) layout;
# Flux expects channels third and batch last, hence the permutation
# (N, C, H, W) -> (H, W, C, N). For i.i.d. random data and BatchNorm the
# order of the two spatial dimensions does not affect the statistics.
dummy_input_py = torch.rand(2, 64, 50, 40)
dummy_input = dummy_input_py.numpy()
dummy_input = permutedims(dummy_input, [3, 4, 2, 1]);

# NOTE(review): both operands below are identical — presumably one side was
# meant to be the PyTorch tensor's shape (e.g. dummy_input_py.shape); verify.
dummy_input |> size, dummy_input |> size

# ((50, 40, 64, 2), (50, 40, 64, 2))

# --- Test (inference) mode comparison ---
# Both layers normalize with their running statistics.
bn1_py.eval();
Flux.testmode!(bn1);

# Backprop a scalar (sum of outputs) through the PyTorch layer.
bn1_py(dummy_input_py).sum().backward()

ps = params(bn1)

# The `let` bindings avoid closing over non-constant globals inside the
# gradient closure.
let bn1_let = bn1, dummy_input_let = dummy_input
    global gs = gradient(ps) do
        dummy_input_let |> bn1_let |> sum
    end
end

# γ/weight gradients: the recorded values below agree to within Float32
# round-off, so test mode is consistent between the two frameworks.
# (gs.params[2] appears to be γ, matching bn1_py.weight — confirm ordering.)
bn1_py.weight.grad, gs.grads[gs.params[2]]
#=
(PyObject tensor([2019.4570, 1947.2040, 1988.5925, 1969.9722, 2009.9149, 1998.1893,
        1988.4197, 2009.7001, 1977.4849, 1994.3187, 1997.6350, 2005.8318,
        1991.8772, 1991.1735, 2011.2142, 1988.0229, 1989.1785, 1997.5776,
        1979.7791, 2007.7289, 2039.7695, 2016.0405, 1984.9316, 2027.1691,
        2026.3792, 1983.8584, 1993.3418, 1982.2946, 2013.3733, 1997.2848,
        2013.8619, 2017.9714, 1998.8685, 1962.2557, 1991.9044, 1965.8220,
        2045.9636, 1967.6460, 2016.0278, 2029.5718, 1997.0713, 2016.2992,
        2024.5354, 1985.1301, 2032.4496, 2030.5781, 1976.5828, 2002.4253,
        1993.6927, 2024.0828, 2004.3256, 2001.8090, 1986.0214, 1977.4012,
        2013.7397, 1996.6542, 1997.0526, 1980.2206, 1996.9481, 2023.0551,
        1989.6249, 1962.9667, 2011.6012, 2005.3535]), Float32[2019.456, 1947.2045, 1988.5906, 1969.9738, 2009.9119, 1998.1907, 1988.4194, 2009.6998, 1977.485, 1994.3193  …  2013.7375, 1996.6555, 1997.0505, 1980.2198, 1996.9487, 2023.0547, 1989.6229, 1962.9647, 2011.6019, 2005.3547])
=#

# --- Training mode comparison ---
# Create new pairs to prevent gradient accumulation
bn1_py = nn.BatchNorm2d(64)
bn1 = BatchNorm(64)

# Both layers now normalize with batch statistics.
bn1_py.train();
Flux.trainmode!(bn1);

ps = params(bn1)

bn1_py(dummy_input_py).sum().backward()
let bn1_let = bn1, dummy_input_let = dummy_input
    global gs = gradient(ps) do
        dummy_input_let |> bn1_let |> sum
    end
end

# In training mode the recorded γ gradients below are all near zero
# (PyTorch ~1e-4, Flux ~1e-3) but do not match elementwise. Per the
# follow-up note at the end of this post, the discrepancy appears to be
# floating-point round-off in the batch-statistics computation rather
# than a semantic difference — TODO confirm against larger-scale tests.
bn1_py.weight.grad, gs.grads[gs.params[2]]
#=
(PyObject tensor([ 1.6493e-04, -2.9656e-05, -2.3891e-05,  1.3843e-04,  2.7258e-04,
         1.7511e-04,  1.8044e-04,  2.8649e-04,  1.0939e-04, -1.1944e-04,
         1.5860e-04,  2.8925e-04, -6.8589e-05,  8.4068e-06,  5.5734e-05,
        -1.2978e-04,  6.3925e-05, -1.8272e-04, -1.9109e-04,  1.8461e-04,
         3.8091e-04,  6.6808e-05,  1.6589e-04, -2.8525e-04, -1.2947e-04,
        -1.2776e-05, -8.5190e-05,  1.0898e-04, -2.3918e-04, -1.5324e-04,
         2.2328e-04,  5.4052e-05,  3.6461e-05,  1.5121e-04,  1.5499e-05,
         1.4421e-04,  1.5866e-04, -1.7723e-04, -5.3185e-05, -4.0692e-04,
        -2.4080e-05,  1.2942e-04,  3.0721e-04, -1.2163e-05,  1.2194e-04,
        -1.5010e-04,  1.3664e-04,  3.5152e-04,  1.3920e-04,  1.6431e-06,
        -6.2837e-05, -3.8363e-04,  1.0273e-04,  2.3168e-05,  4.7588e-05,
        -1.7416e-04, -8.3082e-05, -6.8849e-05,  3.0986e-05, -3.8488e-04,
        -5.8636e-05,  0.0000e+00,  1.6420e-04,  3.3613e-04]), Float32[0.0042684674, -0.0037716627, 0.001232624, -0.0019351244, 0.0068135858, -0.007997453, 0.0022816956, 0.0019184351, -0.0035838783, 0.0003259778  …  3.325939f-5, -0.009773254, 0.0076975375, -0.00045597553, -0.0036969185, -0.0012117028, 0.007847309, 0.0020318031, -0.001496315, -0.0012641922])

=#

3 Likes

It seems that this is just a tiny difference due to floating-point implementation details. If I use normal random weights instead of sum, the error is negligible. On the other hand, there is a significant error in the gradient of MaxPool; see issue: