Zygote + CUDA: scalar getindex with custom activation function using multiplication

This simple model training does not work because of the custom activation function:

julia> using CUDA, Flux

julia> CUDA.allowscalar(false)

julia> const c = 100

julia> f(x) = σ(x)*2*c
f (generic function with 1 method)

julia> m = Chain(Dense(10,1,f)) |> gpu
Chain(Dense(10, 1, f))

julia> x = rand(10,100) |> gpu
10×100 CuArray{Float32,2,Nothing}:

julia> y = rand(1,100) |> gpu
1×100 CuArray{Float32,2,Nothing}:

julia> Flux.train!(Flux.params(m), zip(x,y), ADAM(1f-4)) do x,y
           Flux.mse(m(x), y)
       end
ERROR: scalar getindex is disallowed
 [1] error(::String) at .\error.jl:33
 [2] assertscalar(::String) at C:\Users\Henri\.julia\packages\GPUArrays\4W5rW\src\host\indexing.jl:41
 [3] getindex at C:\Users\Henri\.julia\packages\GPUArrays\4W5rW\src\host\indexing.jl:96 [inlined]
 [4] iterate at .\abstractarray.jl:986 [inlined]
 [5] iterate at .\abstractarray.jl:984 [inlined]
 [6] _zip_iterate_some at .\iterators.jl:352 [inlined]
 [7] _zip_iterate_all at .\iterators.jl:344 [inlined]
 [8] iterate at .\iterators.jl:334 [inlined]
 [9] macro expansion at C:\Users\Henri\.julia\packages\Juno\tLMZd\src\progress.jl:134 [inlined]
 [10] train!(::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{CuArray{Float32,2,Nothing},CuArray{Float32,2,Nothing}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at C:\Users\Henri\.julia\packages\Flux\IjMZL\src\optimise\train.jl:80
 [11] train!(::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{CuArray{Float32,2,Nothing},CuArray{Float32,2,Nothing}}}, ::ADAM) at C:\Users\Henri\.julia\packages\Flux\IjMZL\src\optimise\train.jl:78
 [12] top-level scope at REPL[11]:1
 [13] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088

It doesn't make sense to me: m(x) on its own works fine, and using σ instead of f in the model definition also works, so the problem does come from the activation function.

Am I doing something wrong, or should I open an issue? If the latter, is this a Zygote problem or a CUDA problem? Since the forward propagation is working, I assume it's the gradient computation that creates the scalar operation.

Do you see it now?

What you need to change:

Flux.train!(Flux.params(m), [(x, y)], ADAM(1f-4)) do x, y
    Flux.mse(m(x), y)
end

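To spell out why `zip` fails here (my reading of the error, not stated explicitly in the thread): zipping two matrices iterates them element by element, so `train!` receives a stream of scalar pairs and has to index the CuArrays one element at a time, which `CUDA.allowscalar(false)` forbids. A one-element vector of tuples instead hands the whole batch to the loss at once. A minimal CPU-only sketch of the difference:

```julia
# CPU-only sketch: zip over two matrices yields scalar pairs,
# while a vector of tuples yields the full arrays as one batch.
x = rand(Float32, 10, 100)
y = rand(Float32, 1, 100)

first_zip = first(zip(x, y))   # a pair of scalars: (x[1], y[1])
@show typeof(first_zip)        # Tuple{Float32, Float32}

first_vec = first([(x, y)])    # the whole batch: (x, y)
@show typeof(first_vec)        # a tuple of two Float32 matrices
```

With CuArrays, producing each of those scalar pairs requires a scalar `getindex`, hence the error before the loss is ever evaluated.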
I would also suggest posting your code in a way that is friendly to copy and paste. The more effort you put in, the more likely (and the sooner) someone can help you.

Ha, thanks. Sadly, I did not make that mistake in my actual use case, so I'm just left with an MWE that does not reproduce my error… I'll come up with another.

Solved it. The issue was that the activation function returned a Float64 instead of a Float32. The scalar getindex error was masking a "no method matching" error.
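To illustrate the Float64 pitfall (a reconstruction, since the thread doesn't show the real activation function): multiplying by an untyped `Float64` constant like `100.0` promotes the whole output to Float64, while `Float32` literals (`2f0`, `100f0`) keep the eltype stable. The sketch below defines its own type-preserving sigmoid as a stand-in for Flux's `σ`, just to stay self-contained:

```julia
# Stand-in for Flux's σ that preserves the input's float type.
sigmoid(x) = one(x) / (one(x) + exp(-x))

const c64 = 100.0                    # Float64 constant
bad(x)  = sigmoid(x) * 2 * c64       # Float32 input -> Float64 output (promoted)

const c32 = 100f0                    # Float32 constant
good(x) = sigmoid(x) * 2f0 * c32     # Float32 input -> Float32 output

@show typeof(bad(1f0))    # Float64
@show typeof(good(1f0))   # Float32
```

On the GPU, the Float64 version produces an output array whose eltype no longer matches what downstream methods expect, which is what surfaced here as a misleading scalar getindex error.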