Scalar getindex during Flux gradient computation

I have a issue with a GPU scalar getindex during gradient computation with Flux. Problem is, I just can’t reproduce the problem with a simpler MWE. I swear I tried so hard. So here’s a chunk of the code and what I understand about the problem.

I have a callable struct which is a neural network with a common body but two heads.

struct Gaussian_Actor{M,S,B}
	body::B
	means::M
	stds::S
end

function Gaussian_Actor(state_size::Int, hidden_size::Int, action_size::Int, activation = relu)
	body = Chain(Dense(state_size, hidden_size, activation),
				Dense(hidden_size, hidden_size, activation)) |> gpu
	means = Chain(Dense(hidden_size, hidden_size, relu),
				Dense(hidden_size, action_size)) |> gpu
	stds = Chain(Dense(hidden_size, hidden_size, relu),
				Dense(hidden_size, action_size, softplus)) |> gpu
	return Gaussian_Actor(body, means, stds)
end

Calling the struct returns a tuple with the output of each head :

function (actor::Gaussian_Actor)(state)
	h = actor.body(state |> gpu)
	return actor.means(h), actor.stds(h)
end

In my loss function, I call this function

function pdf(actor::Gaussian_Actor, state, action)
	μs, σs = actor(state)
	prod(Distributions.normpdf.(μs, σs, action), dims = 1) #works with sum !
end

g(A, clip) = A .* (2 * clip .* Flux.Zygote.dropgrad(A .> 0) .+ (1 - clip))


function L_clip(state, action, advantage, actor, actor_old, clip) #state action and advantage are cuarrays, actor and actor_old are gaussian_actors and clip is a scalar.
	ratio = pdf(actor, state, action)./pdf(actor_old, state, action)
	return mean(min.(ratio.* advantage, g(advantage, clip)))
end

And when I call Flux.gradient I get this scalar getindex error. Here I call the full loss but I’ll explain why I know we only need the pdf function.

julia> Flux.gradient(()->L_clip(states |> gpu, actions |> gpu, advantages |> gpu, actor, actor_old, clip), Flux.params(actor))
ERROR: scalar getindex is disallowed 
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] assertscalar(::String) at C:\Users\Henri\.julia\packages\GPUArrays\PkHCM\src\host\indexing.jl:41
 [3] getindex at C:\Users\Henri\.julia\packages\GPUArrays\PkHCM\src\host\indexing.jl:96 [inlined]
 [4] _getindex at .\abstractarray.jl:1083 [inlined]
 [5] getindex at .\abstractarray.jl:1060 [inlined]
 [6] getindex at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\adjtrans.jl:190 [inlined]
 [7] _unsafe_getindex_rs at .\reshapedarray.jl:249 [inlined]
 [8] _unsafe_getindex at .\reshapedarray.jl:246 [inlined]
 [9] getindex at .\reshapedarray.jl:234 [inlined]
 [10] _getindex at .\abstractarray.jl:1100 [inlined]
 [11] getindex at .\abstractarray.jl:1060 [inlined]
 [12] _broadcast_getindex at .\broadcast.jl:614 [inlined]
 [13] _getindex at .\broadcast.jl:644 [inlined]
 [14] _broadcast_getindex at .\broadcast.jl:620 [inlined]
 [15] getindex at .\broadcast.jl:575 [inlined]
 [16] macro expansion at .\broadcast.jl:932 [inlined]
 [17] macro expansion at .\simdloop.jl:77 [inlined]
 [18] copyto! at .\broadcast.jl:931 [inlined]
 [19] copyto! at .\broadcast.jl:886 [inlined]
 [20] copy at .\broadcast.jl:862 [inlined]
 [21] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2},Nothing,Zygote.var"#1153#1160"{Zygote.Context,typeof(g)},Tuple{Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}},Float64}}) at .\broadcast.jl:837
 [22] _broadcast(::Zygote.var"#1153#1160"{Zygote.Context,typeof(g)}, ::Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Vararg{Any,N} where N) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\lib\broadcast.jl:129
 [23] adjoint at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\lib\broadcast.jl:138 [inlined]
 [24] _pullback at C:\Users\Henri\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [25] adjoint at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\lib\lib.jl:175 [inlined]
 [26] _pullback at C:\Users\Henri\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [27] broadcasted at .\broadcast.jl:1263 [inlined]
 [28] L_clip at c:\Users\Henri\OneDrive - UCL\Doctorat\DRL_SIP\src\PPO.jl:57 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(L_clip), ::CuArray{Float32,2}, ::CuArray{Float32,2}, ::Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Gaussian_Actor{Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(softplus),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}}}}}, ::Gaussian_Actor{Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(softplus),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}}}}}, ::Float64) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [30] #68 at .\REPL[112]:1 [inlined]
 [31] _pullback(::Zygote.Context, ::var"#68#69") at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [32] pullback(::Function, ::Zygote.Params) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:172
 [33] gradient(::Function, ::Zygote.Params) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:53
 [34] top-level scope at REPL[112]:1
 [35] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088

Here a few things that make the train loop work (I mean, they all work independently):

  • when I replace prod with sum(..., dims = 1) in pdf, my train loop works BUT NOT the Flux.gradient() function
  • remove , g(advantage, clip) from L_clip
  • remove ratio.* advantage, from L_clip