This simple model training fails, seemingly because of the custom activation function:
julia> using CUDA, Flux
julia> CUDA.allowscalar(false)
julia> const c = 100
100
julia> f(x) = σ(x)*2*c
f (generic function with 1 method)
julia> m = Chain(Dense(10,1,f)) |> gpu
Chain(Dense(10, 1, f))
julia> x = rand(10,100) |> gpu
10×100 CuArray{Float32,2,Nothing}:
[...]
julia> y = rand(1,100) |> gpu
1×100 CuArray{Float32,2,Nothing}:
[...]
julia> Flux.train!(Flux.params(m), zip(x,y), ADAM(1f-4)) do x,y
Flux.mse(m(x), y)
end
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] assertscalar(::String) at C:\Users\Henri\.julia\packages\GPUArrays\4W5rW\src\host\indexing.jl:41
[3] getindex at C:\Users\Henri\.julia\packages\GPUArrays\4W5rW\src\host\indexing.jl:96 [inlined]
[4] iterate at .\abstractarray.jl:986 [inlined]
[5] iterate at .\abstractarray.jl:984 [inlined]
[6] _zip_iterate_some at .\iterators.jl:352 [inlined]
[7] _zip_iterate_all at .\iterators.jl:344 [inlined]
[8] iterate at .\iterators.jl:334 [inlined]
[9] macro expansion at C:\Users\Henri\.julia\packages\Juno\tLMZd\src\progress.jl:134 [inlined]
[10] train!(::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{CuArray{Float32,2,Nothing},CuArray{Float32,2,Nothing}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at C:\Users\Henri\.julia\packages\Flux\IjMZL\src\optimise\train.jl:80
[11] train!(::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{CuArray{Float32,2,Nothing},CuArray{Float32,2,Nothing}}}, ::ADAM) at C:\Users\Henri\.julia\packages\Flux\IjMZL\src\optimise\train.jl:78
[12] top-level scope at REPL[11]:1
[13] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
This does not make sense to me: m(x) works fine, and using σ instead of f in the model definition also works, so the problem does seem to come from the activation function.
Am I doing something wrong, or should I open an issue? If the latter, is this a Zygote problem or a CUDA problem? Since the forward propagation works, I assume it is the gradient computation that creates the scalar operation.
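Possibly relevant: the iterate/_zip_iterate frames in the trace come from iterating the zip itself, before any gradient is taken. A CPU-only sketch of that behavior (no Flux or CUDA needed; this is my reading of the trace, not a confirmed diagnosis, and the data/first_batch names are just for illustration):

```julia
# zip over two matrices iterates them element by element, so train!
# would receive scalar (x[i], y[i]) pairs instead of (input, target) batches.
x = rand(10, 100)
y = rand(1, 100)

first_pair = first(zip(x, y))
println(typeof(first_pair))   # Tuple{Float64, Float64} — a pair of scalars

# Wrapping the whole batch in a one-element vector keeps the arrays intact:
data = [(x, y)]
first_batch = first(data)
println(size(first_batch[1])) # (10, 100) — the full input matrix
```

On a CuArray that first element access is exactly a scalar getindex, which allowscalar(false) forbids.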