Scalar getindex during Flux gradient computation

I have an issue with scalar getindex on the GPU during gradient computation with Flux. The problem is, I just can't reproduce it with a simpler MWE. I swear I tried so hard. So here's a chunk of the code and what I understand about the problem.

I have a callable struct which is a neural network with a common body but two heads.

```julia
struct Gaussian_Actor{M,S,B}
    body::B
    means::M
    stds::S
end

function Gaussian_Actor(state_size::Int, hidden_size::Int, action_size::Int, activation = relu)
    body = Chain(Dense(state_size, hidden_size, activation),
                 Dense(hidden_size, hidden_size, activation)) |> gpu
    means = Chain(Dense(hidden_size, hidden_size, relu),
                  Dense(hidden_size, action_size)) |> gpu
    stds = Chain(Dense(hidden_size, hidden_size, relu),
                 Dense(hidden_size, action_size, softplus)) |> gpu
    return Gaussian_Actor(body, means, stds)
end
```

Calling the struct returns a tuple with the output of each head:

```julia
function (actor::Gaussian_Actor)(state)
    h = actor.body(state |> gpu)
    return actor.means(h), actor.stds(h)
end
```

In my loss function, I call this as follows:

```julia
function pdf(actor::Gaussian_Actor, state, action)
    μs, σs = actor(state)
    prod(Distributions.normpdf.(μs, σs, action), dims = 1) # works with `sum`!
end

g(A, clip) = A .* (2 * clip .* Flux.Zygote.dropgrad(A .> 0) .+ (1 - clip))

# state, action and advantage are CuArrays; actor and actor_old are
# Gaussian_Actors; clip is a scalar.
function L_clip(state, action, advantage, actor, actor_old, clip)
    ratio = pdf(actor, state, action) ./ pdf(actor_old, state, action)
end
```
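For reference, here is a GPU-friendly rewrite of `pdf` I am considering, assuming the scalar indexing really comes from the pullback of `prod` on a `CuArray`. Since `normpdf` is strictly positive, the product over `dims = 1` can be expressed as `exp∘sum∘log`, which only involves broadcasts and a `sum` reduction, both of which have GPU-friendly adjoints:

```julia
# Sketch of a workaround (assumes all densities are strictly positive,
# which holds for a Gaussian pdf): prod(p, dims = 1) == exp.(sum(log.(p), dims = 1))
function pdf(actor::Gaussian_Actor, state, action)
    μs, σs = actor(state)
    p = Distributions.normpdf.(μs, σs, action)
    exp.(sum(log.(p), dims = 1))
end
```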

And when I call `Flux.gradient` I get this scalar getindex error. Here I call the full loss, but below I explain why I know the problem is already in the `pdf` function alone.

```julia
julia> Flux.gradient(()->L_clip(states |> gpu, actions |> gpu, advantages |> gpu, actor, actor_old, clip), Flux.params(actor))
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] assertscalar(::String) at C:\Users\Henri\.julia\packages\GPUArrays\PkHCM\src\host\indexing.jl:41
 [3] getindex at C:\Users\Henri\.julia\packages\GPUArrays\PkHCM\src\host\indexing.jl:96 [inlined]
 [4] _getindex at .\abstractarray.jl:1083 [inlined]
 [5] getindex at .\abstractarray.jl:1060 [inlined]
 [7] _unsafe_getindex_rs at .\reshapedarray.jl:249 [inlined]
 [8] _unsafe_getindex at .\reshapedarray.jl:246 [inlined]
 [9] getindex at .\reshapedarray.jl:234 [inlined]
 [10] _getindex at .\abstractarray.jl:1100 [inlined]
 [11] getindex at .\abstractarray.jl:1060 [inlined]
 [16] macro expansion at .\broadcast.jl:932 [inlined]
 [17] macro expansion at .\simdloop.jl:77 [inlined]
 [28] L_clip at c:\Users\Henri\OneDrive - UCL\Doctorat\DRL_SIP\src\PPO.jl:57 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(L_clip), ::CuArray{Float32,2}, ::CuArray{Float32,2}, ::Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Gaussian_Actor{Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(softplus),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}}}}}, ::Gaussian_Actor{Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(softplus),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}}}}}, ::Float64) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [30] #68 at .\REPL[112]:1 [inlined]
 [31] _pullback(::Zygote.Context, ::var"#68#69") at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [32] pullback(::Function, ::Zygote.Params) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:172
```
Here is what I found while isolating the problem:

• when I replace `prod` with `sum(..., dims = 1)` in `pdf`, my train loop works BUT NOT the `Flux.gradient()` call above;
• removing `, g(advantage, clip)` from `L_clip`;
• removing `ratio .* advantage, ` from `L_clip`.
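For what it's worth, both reductions differentiate fine on the CPU, which is why I suspect the GPU pullback of `prod` specifically. A quick CPU-only sanity check (hypothetical names, just to illustrate the equivalence I rely on above):

```julia
using Zygote

x = rand(Float32, 3, 4) .+ 0.1f0  # strictly positive entries

# Gradient of the product over dims = 1, as in `pdf`...
g_prod = Zygote.gradient(x -> sum(prod(x, dims = 1)), x)[1]

# ...and of the exp∘sum∘log form, whose pullback is broadcast-only.
g_logsum = Zygote.gradient(x -> sum(exp.(sum(log.(x), dims = 1))), x)[1]

g_prod ≈ g_logsum  # true: same gradient for positive entries
```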