I am not able to use `BatchNorm` with `track_stats=false` on the GPU:
```julia
julia> using Flux

julia> model = Chain(Dense(10,5),BatchNorm(5,relu; track_stats=false))
Chain(Dense(10, 5), BatchNorm(5, relu))

julia> modelgpu = gpu(model)
Chain(Dense(10, 5), BatchNorm(5, relu))

julia> x = randn(Float32, 10, 100);

julia> model(x)
5×100 Matrix{Float32}:
 0.897995 1.41563 0.0 0.625536 0.0 … 0.358508 0.302867 0.0 0.0
 0.903481 0.0 0.782808 0.673593 0.0 1.81217 0.277665 0.0 0.0
 1.37133 0.0 1.94733 0.259229 0.0 1.73941 0.0 0.137841 0.287855
 0.179537 0.0 0.316961 0.0 0.446272 0.0 0.0 0.0 0.0
 0.721404 1.4531 0.0 0.0 0.183346 0.0 0.0 1.31021 0.511247

julia> modelgpu(gpu(x))
ERROR: AssertionError: BatchNorm: only track_stats=true supported on gpu
Stacktrace:
 [1] (::BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing})(x::CUDA.CuArray{Float32, 2}, cache::Nothing)
   @ Flux.CUDAint ~/.julia/packages/Flux/6o4DQ/src/cuda/cudnn.jl:7
 [2] (::BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing})(x::CUDA.CuArray{Float32, 2})
   @ Flux.CUDAint ~/.julia/packages/Flux/6o4DQ/src/cuda/cudnn.jl:6
 [3] applychain(fs::Tuple{BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing}}, x::CUDA.CuArray{Float32, 2}) (repeats 2 times)
   @ Flux ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:36
 [4] (::Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 1}}, BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing}}})(x::CUDA.CuArray{Float32, 2})
   @ Flux ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:38
 [5] top-level scope
   @ REPL[40]:1
 [6] top-level scope
   @ ~/.julia/packages/CUDA/3VnCC/src/initialization.jl:81
```
Julia and package versions:
```julia
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)

(jl_VSWxj0) pkg> st
Status `/tmp/jl_VSWxj0/Project.toml`
  [587475ba] Flux v0.12.3
```
Any idea why only `track_stats=true` is supported on the GPU?
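As a possible stopgap, the batch statistics can be computed by hand. The sketch below is a minimal illustration under stated assumptions: `batchstats_norm` is a hypothetical helper, not part of Flux, and it assumes the features × batch layout that `Dense(10, 5)` produces above. As far as I understand, it mirrors what `BatchNorm` does with `track_stats=false`: normalize each feature with the current batch's mean and (biased) variance, then apply a scale `γ`, a shift `β`, and the activation. It only uses broadcasting and `mean` reductions, so it should also accept a `CuArray`, but that has not been verified here.

```julia
using Statistics

# Hypothetical helper, NOT part of Flux's API: normalize each feature (row) of a
# features × batch matrix using only the mean and variance of the current batch.
# Only broadcasting and `mean` reductions are used, so it is intended to work for
# both Array and CuArray inputs (the GPU path is an untested assumption).
function batchstats_norm(x::AbstractMatrix, γ::AbstractVector, β::AbstractVector;
                         ϵ = 1f-5, act = identity)
    μ = mean(x; dims = 2)                   # per-feature batch mean, size (features, 1)
    σ² = mean((x .- μ) .^ 2; dims = 2)      # per-feature biased batch variance
    xhat = (x .- μ) ./ sqrt.(σ² .+ ϵ)       # zero mean, unit variance per feature
    return act.(γ .* xhat .+ β)             # scale and shift, then activation
end

# Usage with the shapes from the report: 5 features, batch of 100.
x = randn(Float32, 5, 100)
γ, β = ones(Float32, 5), zeros(Float32, 5)
y = batchstats_norm(x, γ, β; act = z -> max(z, zero(z)))  # ReLU-like activation
```

Here `γ` and `β` are plain arrays; making them trainable inside a `Chain` would need a small custom layer, which is left out of this sketch.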