I found that Flux’s BatchNorm output (in both training and test mode) and its gradient in test mode are consistent with PyTorch, but the gradient in training mode is not. This is strange, since all of the parameters look the same.
Example:
using Flux
using PyCall
const torch = pyimport("torch")
const nn = pyimport("torch.nn")
bn1_py = nn.BatchNorm2d(64)
bn1 = BatchNorm(64)
# The following shows that the initial parameters and running statistics agree
bn1.ϵ, bn1_py.eps
# (1.0f-5, 1.0e-5)
bn1.momentum, bn1_py.momentum
# (0.1f0, 0.1)
bn1.γ, bn1_py.weight
#=
(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], PyObject Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True))
=#
bn1.β, bn1_py.bias
#=
(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], PyObject Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
requires_grad=True))
=#
bn1.μ, bn1_py.running_mean
#=
(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], PyObject tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
=#
bn1.σ², bn1_py.running_var
#=
(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], PyObject tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]))
=#
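Rather than eyeballing the dumps above, the same comparison can be done programmatically; a minimal sketch of my own (not from the original session), relying on PyCall's automatic numpy-to-Array conversion:
# Check that parameters and running statistics agree elementwise
all(isapprox.(bn1.γ, bn1_py.weight.detach().numpy()))   # scale
all(isapprox.(bn1.β, bn1_py.bias.detach().numpy()))     # shift
all(isapprox.(bn1.μ, bn1_py.running_mean.numpy()))      # running mean
all(isapprox.(bn1.σ², bn1_py.running_var.numpy()))      # running variance
# all four should return true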
dummy_input_py = torch.rand(2, 64, 50, 40)             # (N, C, H, W)
dummy_input = dummy_input_py.numpy()
dummy_input = permutedims(dummy_input, [3, 4, 2, 1]);   # (N, C, H, W) -> (H, W, C, N); Flux expects channels on dim 3
dummy_input |> size
# (50, 40, 64, 2)
bn1_py.eval();
Flux.testmode!(bn1);
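As a side check (my own sketch, not part of the original output), the test-mode forward outputs can be compared directly; the PyTorch output is NCHW, so permute the Flux output back before comparing:
out_py = bn1_py(dummy_input_py).detach().numpy()
out_jl = bn1(dummy_input)
maximum(abs.(permutedims(out_jl, (4, 3, 1, 2)) .- out_py))  # expect ≈ 0 (Float32 rounding only)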
bn1_py(dummy_input_py).sum().backward()
ps = params(bn1)
let bn1_let = bn1, dummy_input_let = dummy_input
global gs = gradient(ps) do
dummy_input_let |> bn1_let |> sum
end
end
bn1_py.weight.grad, gs.grads[gs.params[2]]
#=
(PyObject tensor([2019.4570, 1947.2040, 1988.5925, 1969.9722, 2009.9149, 1998.1893,
1988.4197, 2009.7001, 1977.4849, 1994.3187, 1997.6350, 2005.8318,
1991.8772, 1991.1735, 2011.2142, 1988.0229, 1989.1785, 1997.5776,
1979.7791, 2007.7289, 2039.7695, 2016.0405, 1984.9316, 2027.1691,
2026.3792, 1983.8584, 1993.3418, 1982.2946, 2013.3733, 1997.2848,
2013.8619, 2017.9714, 1998.8685, 1962.2557, 1991.9044, 1965.8220,
2045.9636, 1967.6460, 2016.0278, 2029.5718, 1997.0713, 2016.2992,
2024.5354, 1985.1301, 2032.4496, 2030.5781, 1976.5828, 2002.4253,
1993.6927, 2024.0828, 2004.3256, 2001.8090, 1986.0214, 1977.4012,
2013.7397, 1996.6542, 1997.0526, 1980.2206, 1996.9481, 2023.0551,
1989.6249, 1962.9667, 2011.6012, 2005.3535]), Float32[2019.456, 1947.2045, 1988.5906, 1969.9738, 2009.9119, 1998.1907, 1988.4194, 2009.6998, 1977.485, 1994.3193 … 2013.7375, 1996.6555, 1997.0505, 1980.2198, 1996.9487, 2023.0547, 1989.6229, 1962.9647, 2011.6019, 2005.3547])
=#
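So in test mode the two gradients agree; a quick way to quantify it (my own check):
grad_py = bn1_py.weight.grad.numpy()
grad_jl = gs.grads[gs.params[2]]
maximum(abs.(grad_jl .- grad_py))  # tiny compared to the values themselves (~2000)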
# Create a fresh pair of layers to avoid accumulating gradients from the previous run
bn1_py = nn.BatchNorm2d(64)
bn1 = BatchNorm(64)
bn1_py.train();
Flux.trainmode!(bn1);
ps = params(bn1)
bn1_py(dummy_input_py).sum().backward()
let bn1_let = bn1, dummy_input_let = dummy_input
global gs = gradient(ps) do
dummy_input_let |> bn1_let |> sum
end
end
bn1_py.weight.grad, gs.grads[gs.params[2]]
#=
(PyObject tensor([ 1.6493e-04, -2.9656e-05, -2.3891e-05, 1.3843e-04, 2.7258e-04,
1.7511e-04, 1.8044e-04, 2.8649e-04, 1.0939e-04, -1.1944e-04,
1.5860e-04, 2.8925e-04, -6.8589e-05, 8.4068e-06, 5.5734e-05,
-1.2978e-04, 6.3925e-05, -1.8272e-04, -1.9109e-04, 1.8461e-04,
3.8091e-04, 6.6808e-05, 1.6589e-04, -2.8525e-04, -1.2947e-04,
-1.2776e-05, -8.5190e-05, 1.0898e-04, -2.3918e-04, -1.5324e-04,
2.2328e-04, 5.4052e-05, 3.6461e-05, 1.5121e-04, 1.5499e-05,
1.4421e-04, 1.5866e-04, -1.7723e-04, -5.3185e-05, -4.0692e-04,
-2.4080e-05, 1.2942e-04, 3.0721e-04, -1.2163e-05, 1.2194e-04,
-1.5010e-04, 1.3664e-04, 3.5152e-04, 1.3920e-04, 1.6431e-06,
-6.2837e-05, -3.8363e-04, 1.0273e-04, 2.3168e-05, 4.7588e-05,
-1.7416e-04, -8.3082e-05, -6.8849e-05, 3.0986e-05, -3.8488e-04,
-5.8636e-05, 0.0000e+00, 1.6420e-04, 3.3613e-04]), Float32[0.0042684674, -0.0037716627, 0.001232624, -0.0019351244, 0.0068135858, -0.007997453, 0.0022816956, 0.0019184351, -0.0035838783, 0.0003259778 … 3.325939f-5, -0.009773254, 0.0076975375, -0.00045597553, -0.0036969185, -0.0012117028, 0.007847309, 0.0020318031, -0.001496315, -0.0012641922])
=#
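For reference: in training mode the output is γ .* x̂ .+ β, where x̂ is the input normalized with the batch statistics, so the gradient of sum(bn(x)) with respect to γ is the per-channel sum of x̂, which should be ≈ 0 up to floating-point error. A sketch of that reference value (my own check, assuming the biased batch variance that PyTorch uses in its training-mode forward):
using Statistics
# per-channel batch statistics over the (H, W, N) dims of the (H, W, C, N) array
μ_b  = mean(dummy_input; dims = (1, 2, 4))
σ²_b = var(dummy_input; dims = (1, 2, 4), corrected = false)
x̂ = (dummy_input .- μ_b) ./ sqrt.(σ²_b .+ bn1.ϵ)
vec(sum(x̂; dims = (1, 2, 4)))  # analytic γ-gradient; compare against the two results above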