BatchNorm gradient is inconsistent with PyTorch in training mode

I found that Flux’s BatchNorm output (in both training and test mode) and its gradient in test mode are consistent with PyTorch, but its gradient in training mode is not. This is strange, since all the parameters look the same.


using Flux
using PyCall

const torch = pyimport("torch")
const nn = pyimport("torch.nn")

bn1_py = nn.BatchNorm2d(64)
bn1 = BatchNorm(64)

# Following code shows consistency between initial parameters
bn1.ϵ, bn1_py.eps
# (1.0f-5, 1.0e-5)

bn1.momentum, bn1_py.momentum 
# (0.1f0, 0.1)

bn1.γ, bn1_py.weight
(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], PyObject Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True))

bn1.β, bn1_py.bias
(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], PyObject Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True))

bn1.μ, bn1_py.running_mean
(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], PyObject tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

bn1.σ², bn1_py.running_var
(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], PyObject tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]))

dummy_input_py = torch.rand(2, 64, 50, 40)
dummy_input = dummy_input_py.numpy()
dummy_input = permutedims(dummy_input, [3, 4, 2, 1]);

dummy_input |> size, dummy_input |> size

# ((50, 40, 64, 2), (50, 40, 64, 2))



ps = params(bn1)

# Differentiate the summed BatchNorm output with respect to the parameters in
# `ps`. The `let` rebinds the globals `bn1` and `dummy_input` to local names for
# the closure; the resulting gradients are stored in the global `gs`.
let bn1_let = bn1, dummy_input_let = dummy_input
    global gs = gradient(ps) do
        dummy_input_let |> bn1_let |> sum
    end
end

bn1_py.weight.grad, gs.grads[gs.params[2]]
(PyObject tensor([2019.4570, 1947.2040, 1988.5925, 1969.9722, 2009.9149, 1998.1893,
        1988.4197, 2009.7001, 1977.4849, 1994.3187, 1997.6350, 2005.8318,
        1991.8772, 1991.1735, 2011.2142, 1988.0229, 1989.1785, 1997.5776,
        1979.7791, 2007.7289, 2039.7695, 2016.0405, 1984.9316, 2027.1691,
        2026.3792, 1983.8584, 1993.3418, 1982.2946, 2013.3733, 1997.2848,
        2013.8619, 2017.9714, 1998.8685, 1962.2557, 1991.9044, 1965.8220,
        2045.9636, 1967.6460, 2016.0278, 2029.5718, 1997.0713, 2016.2992,
        2024.5354, 1985.1301, 2032.4496, 2030.5781, 1976.5828, 2002.4253,
        1993.6927, 2024.0828, 2004.3256, 2001.8090, 1986.0214, 1977.4012,
        2013.7397, 1996.6542, 1997.0526, 1980.2206, 1996.9481, 2023.0551,
        1989.6249, 1962.9667, 2011.6012, 2005.3535]), Float32[2019.456, 1947.2045, 1988.5906, 1969.9738, 2009.9119, 1998.1907, 1988.4194, 2009.6998, 1977.485, 1994.3193  …  2013.7375, 1996.6555, 1997.0505, 1980.2198, 1996.9487, 2023.0547, 1989.6229, 1962.9647, 2011.6019, 2005.3547])

# Create new pairs to prevent gradient accumulation
bn1_py = nn.BatchNorm2d(64)
bn1 = BatchNorm(64)


ps = params(bn1)

# Differentiate the summed output of the freshly created BatchNorm layer with
# respect to the parameters in `ps`. The `let` rebinds the globals `bn1` and
# `dummy_input` to local names for the closure; the resulting gradients are
# stored in the global `gs`.
let bn1_let = bn1, dummy_input_let = dummy_input
    global gs = gradient(ps) do
        dummy_input_let |> bn1_let |> sum
    end
end

bn1_py.weight.grad, gs.grads[gs.params[2]]
(PyObject tensor([ 1.6493e-04, -2.9656e-05, -2.3891e-05,  1.3843e-04,  2.7258e-04,
         1.7511e-04,  1.8044e-04,  2.8649e-04,  1.0939e-04, -1.1944e-04,
         1.5860e-04,  2.8925e-04, -6.8589e-05,  8.4068e-06,  5.5734e-05,
        -1.2978e-04,  6.3925e-05, -1.8272e-04, -1.9109e-04,  1.8461e-04,
         3.8091e-04,  6.6808e-05,  1.6589e-04, -2.8525e-04, -1.2947e-04,
        -1.2776e-05, -8.5190e-05,  1.0898e-04, -2.3918e-04, -1.5324e-04,
         2.2328e-04,  5.4052e-05,  3.6461e-05,  1.5121e-04,  1.5499e-05,
         1.4421e-04,  1.5866e-04, -1.7723e-04, -5.3185e-05, -4.0692e-04,
        -2.4080e-05,  1.2942e-04,  3.0721e-04, -1.2163e-05,  1.2194e-04,
        -1.5010e-04,  1.3664e-04,  3.5152e-04,  1.3920e-04,  1.6431e-06,
        -6.2837e-05, -3.8363e-04,  1.0273e-04,  2.3168e-05,  4.7588e-05,
        -1.7416e-04, -8.3082e-05, -6.8849e-05,  3.0986e-05, -3.8488e-04,
        -5.8636e-05,  0.0000e+00,  1.6420e-04,  3.3613e-04]), Float32[0.0042684674, -0.0037716627, 0.001232624, -0.0019351244, 0.0068135858, -0.007997453, 0.0022816956, 0.0019184351, -0.0035838783, 0.0003259778  …  3.325939f-5, -0.009773254, 0.0076975375, -0.00045597553, -0.0036969185, -0.0012117028, 0.007847309, 0.0020318031, -0.001496315, -0.0012641922])



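It seems that this is just a tiny difference caused by implementation details of the computation. (In training mode the normalized output of each channel sums to roughly zero, so the exact gradient of `sum` with respect to the weight is zero and both results are essentially floating-point rounding noise.) If I use normally distributed random weights instead of a plain sum, the error is negligible. On the other hand, there’s a significant error in the gradient of MaxPool, see issue: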

1 Like

Just saw your PRs, thanks man…! I’m just glad someone notices such differences, so I don’t have to wonder why my stuff doesn’t work.