BatchNorm: only track_stats=true supported on gpu

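I am unable to use BatchNorm with track_stats=false on the GPU. The same model runs fine on the CPU but throws an AssertionError once it is moved to the GPU: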

julia> using Flux

julia> model = Chain(Dense(10,5),BatchNorm(5,relu; track_stats=false))
Chain(Dense(10, 5), BatchNorm(5, relu))

julia> modelgpu = gpu(model)
Chain(Dense(10, 5), BatchNorm(5, relu))

julia> x = randn(Float32, 10, 100);

julia> model(x)
5×100 Matrix{Float32}:
 0.897995  1.41563  0.0       0.625536  0.0       …  0.358508  0.302867  0.0       0.0
 0.903481  0.0      0.782808  0.673593  0.0          1.81217   0.277665  0.0       0.0
 1.37133   0.0      1.94733   0.259229  0.0          1.73941   0.0       0.137841  0.287855
 0.179537  0.0      0.316961  0.0       0.446272     0.0       0.0       0.0       0.0
 0.721404  1.4531   0.0       0.0       0.183346     0.0       0.0       1.31021   0.511247

julia> modelgpu(gpu(x))
ERROR: AssertionError: BatchNorm: only track_stats=true supported on gpu
Stacktrace:
 [1] (::BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing})(x::CUDA.CuArray{Float32, 2}, cache::Nothing)
   @ Flux.CUDAint ~/.julia/packages/Flux/6o4DQ/src/cuda/cudnn.jl:7
 [2] (::BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing})(x::CUDA.CuArray{Float32, 2})
   @ Flux.CUDAint ~/.julia/packages/Flux/6o4DQ/src/cuda/cudnn.jl:6
 [3] applychain(fs::Tuple{BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing}}, x::CUDA.CuArray{Float32, 2}) (repeats 2 times)
   @ Flux ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:36
 [4] (::Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 1}}, BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing}}})(x::CUDA.CuArray{Float32, 2})
   @ Flux ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:38
 [5] top-level scope
   @ REPL[40]:1
 [6] top-level scope
   @ ~/.julia/packages/CUDA/3VnCC/src/initialization.jl:81

Julia and package versions:

julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)

(jl_VSWxj0) pkg> st
      Status `/tmp/jl_VSWxj0/Project.toml`
  [587475ba] Flux v0.12.3

Any idea why only track_stats=true is supported on GPU?

cuDNN unconditionally tracks running statistics for batchnorm when you invoke cudnnBatchNormalizationForwardTraining, which is what Flux does. Using cudnnBatchNormalizationForwardInference instead when track_stats=false would be the path forward here, though I’m not sure it can still give the proper behaviour.
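For reference, what track_stats=false is meant to compute is a normalization that uses only the statistics of the current batch, with no running averages. Here is a minimal sketch of that in plain Julia. The function name batchnorm_nostats and its keyword names are made up for illustration and are not part of Flux. It follows the standard batchnorm formula with the biased variance, which I am assuming matches what Flux does on the CPU. It only uses generic array operations, so it may also work on a CuArray, but I have not tested that.

```julia
using Statistics

# Hypothetical helper, NOT part of Flux: batch normalization of a
# (features × batch) matrix using only the current batch's statistics.
function batchnorm_nostats(x::AbstractMatrix, scale::AbstractVector, shift::AbstractVector;
                           eps = 1f-5, act = identity)
    mu = mean(x; dims = 2)                                   # per-feature batch mean
    sigma2 = var(x; dims = 2, mean = mu, corrected = false)  # biased per-feature batch variance
    xhat = (x .- mu) ./ sqrt.(sigma2 .+ eps)                 # normalize each feature (row)
    return act.(scale .* xhat .+ shift)                      # affine transform, then activation
end

x = randn(Float32, 5, 100)
y = batchnorm_nostats(x, ones(Float32, 5), zeros(Float32, 5))
```

To mimic BatchNorm(5, relu; track_stats=false) you could pass act = z -> max(z, 0) together with the layer's scale and shift parameters. This bypasses cuDNN entirely, so expect it to be slower than the fused kernel.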

Thanks for pointing out which cuDNN functions BatchNorm uses on the GPU.

I’ve opened issue #1606 on FluxML/Flux.jl, “Allow BatchNorm training on CUDA with `track_stats=false`”. No promises on any kind of timeline, but it’s documented now. Thanks for reporting!
