Not able to use BatchNorm with track_stats=false on GPU.
julia> using Flux
julia> model = Chain(Dense(10,5),BatchNorm(5,relu; track_stats=false))
Chain(Dense(10, 5), BatchNorm(5, relu))
julia> modelgpu = gpu(model)
Chain(Dense(10, 5), BatchNorm(5, relu))
julia> x = randn(Float32, 10, 100);
julia> model(x)
5×100 Matrix{Float32}:
0.897995 1.41563 0.0 0.625536 0.0 … 0.358508 0.302867 0.0 0.0
0.903481 0.0 0.782808 0.673593 0.0 1.81217 0.277665 0.0 0.0
1.37133 0.0 1.94733 0.259229 0.0 1.73941 0.0 0.137841 0.287855
0.179537 0.0 0.316961 0.0 0.446272 0.0 0.0 0.0 0.0
0.721404 1.4531 0.0 0.0 0.183346 0.0 0.0 1.31021 0.511247
julia> modelgpu(gpu(x))
ERROR: AssertionError: BatchNorm: only track_stats=true supported on gpu
Stacktrace:
[1] (::BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing})(x::CUDA.CuArray{Float32, 2}, cache::Nothing)
@ Flux.CUDAint ~/.julia/packages/Flux/6o4DQ/src/cuda/cudnn.jl:7
[2] (::BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing})(x::CUDA.CuArray{Float32, 2})
@ Flux.CUDAint ~/.julia/packages/Flux/6o4DQ/src/cuda/cudnn.jl:6
[3] applychain(fs::Tuple{BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing}}, x::CUDA.CuArray{Float32, 2}) (repeats 2 times)
@ Flux ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:36
[4] (::Chain{Tuple{Dense{typeof(identity), CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 1}}, BatchNorm{typeof(relu), CUDA.CuArray{Float32, 1}, Float32, Nothing}}})(x::CUDA.CuArray{Float32, 2})
@ Flux ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:38
[5] top-level scope
@ REPL[40]:1
[6] top-level scope
@ ~/.julia/packages/CUDA/3VnCC/src/initialization.jl:81
Julia and package versions:
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
(jl_VSWxj0) pkg> st
Status `/tmp/jl_VSWxj0/Project.toml`
[587475ba] Flux v0.12.3
Any idea why only track_stats=true is supported on GPU?
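For reference, a BatchNorm layer with track_stats=false always normalises each feature with the mean and variance of the current batch. The sketch below shows that computation in plain Julia. It assumes a features × batch matrix and leaves out the learnable scale and shift, which start at 1 and 0 when the layer is created. The function name batchstats_norm is hypothetical and is not part of Flux. It is meant to illustrate what the layer has to compute, not to reproduce Flux's CPU or cuDNN code.

```julia
using Statistics

# Hypothetical helper (not a Flux API): normalise each row (feature) of a
# features × batch matrix using statistics of the current batch only,
# i.e. the behaviour expected from BatchNorm with track_stats=false.
function batchstats_norm(x::AbstractMatrix{T}; ϵ = T(1e-5)) where {T<:AbstractFloat}
    μ = mean(x; dims = 2)                    # per-feature batch mean, size (features, 1)
    σ² = var(x; dims = 2, corrected = false) # per-feature biased batch variance
    return (x .- μ) ./ sqrt.(σ² .+ ϵ)        # broadcast back over the batch dimension
end
```

Because it only uses broadcasting and mean/var with dims, the same function may also work on a CuArray, assuming CUDA.jl supports those reductions for your versions. In that case, something like relu.(batchstats_norm(modelgpu[1](gpu(x)))) could serve as a stopgap for the second layer at initialisation. This is a workaround sketch, not a fix for the assertion in Flux's cuDNN path.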