I have an issue with a GPU scalar getindex error during gradient computation with Flux. The problem is, I just can't reproduce it with a simpler MWE, and I swear I tried hard. So here is the relevant chunk of the code and what I understand about the problem.
I have a callable struct which is a neural network with a common body but two heads.
struct Gaussian_Actor{M,S,B}
    body::B
    means::M
    stds::S
end
function Gaussian_Actor(state_size::Int, hidden_size::Int, action_size::Int, activation = relu)
    body = Chain(Dense(state_size, hidden_size, activation),
                 Dense(hidden_size, hidden_size, activation)) |> gpu
    means = Chain(Dense(hidden_size, hidden_size, relu),
                  Dense(hidden_size, action_size)) |> gpu
    stds = Chain(Dense(hidden_size, hidden_size, relu),
                 Dense(hidden_size, action_size, softplus)) |> gpu
    return Gaussian_Actor(body, means, stds)
end
Calling the struct returns a tuple with the output of each head:
function (actor::Gaussian_Actor)(state)
    h = actor.body(state |> gpu)
    return actor.means(h), actor.stds(h)
end
My loss function calls the actor through this pdf function:
function pdf(actor::Gaussian_Actor, state, action)
    μs, σs = actor(state)
    prod(Distributions.normpdf.(μs, σs, action), dims = 1) # works with sum!
end
g(A, clip) = A .* (2 * clip .* Flux.Zygote.dropgrad(A .> 0) .+ (1 - clip))
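Written out: since the gradient is stopped through the indicator with dropgrad, g evaluates elementwise to (1 + clip) * A where A > 0 and (1 - clip) * A elsewhere, i.e. the clipped PPO branch. A quick CPU sanity check (values purely illustrative):

g([1.0f0, -1.0f0], 0.2f0)  # returns [1.2f0, -0.8f0], i.e. [(1 + clip) * 1, (1 - clip) * -1]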
function L_clip(state, action, advantage, actor, actor_old, clip) # state, action and advantage are CuArrays; actor and actor_old are Gaussian_Actors; clip is a scalar
    ratio = pdf(actor, state, action) ./ pdf(actor_old, state, action)
    return mean(min.(ratio .* advantage, g(advantage, clip)))
end
And when I call Flux.gradient, I get this scalar getindex error. Here I call the full loss, but I'll explain below why I know we only need the pdf function.
julia> Flux.gradient(()->L_clip(states |> gpu, actions |> gpu, advantages |> gpu, actor, actor_old, clip), Flux.params(actor))
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] assertscalar(::String) at C:\Users\Henri\.julia\packages\GPUArrays\PkHCM\src\host\indexing.jl:41
[3] getindex at C:\Users\Henri\.julia\packages\GPUArrays\PkHCM\src\host\indexing.jl:96 [inlined]
[4] _getindex at .\abstractarray.jl:1083 [inlined]
[5] getindex at .\abstractarray.jl:1060 [inlined]
[6] getindex at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\adjtrans.jl:190 [inlined]
[7] _unsafe_getindex_rs at .\reshapedarray.jl:249 [inlined]
[8] _unsafe_getindex at .\reshapedarray.jl:246 [inlined]
[9] getindex at .\reshapedarray.jl:234 [inlined]
[10] _getindex at .\abstractarray.jl:1100 [inlined]
[11] getindex at .\abstractarray.jl:1060 [inlined]
[12] _broadcast_getindex at .\broadcast.jl:614 [inlined]
[13] _getindex at .\broadcast.jl:644 [inlined]
[14] _broadcast_getindex at .\broadcast.jl:620 [inlined]
[15] getindex at .\broadcast.jl:575 [inlined]
[16] macro expansion at .\broadcast.jl:932 [inlined]
[17] macro expansion at .\simdloop.jl:77 [inlined]
[18] copyto! at .\broadcast.jl:931 [inlined]
[19] copyto! at .\broadcast.jl:886 [inlined]
[20] copy at .\broadcast.jl:862 [inlined]
[21] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2},Nothing,Zygote.var"#1153#1160"{Zygote.Context,typeof(g)},Tuple{Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}},Float64}}) at .\broadcast.jl:837
[22] _broadcast(::Zygote.var"#1153#1160"{Zygote.Context,typeof(g)}, ::Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Vararg{Any,N} where N) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\lib\broadcast.jl:129
[23] adjoint at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\lib\broadcast.jl:138 [inlined]
[24] _pullback at C:\Users\Henri\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
[25] adjoint at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\lib\lib.jl:175 [inlined]
[26] _pullback at C:\Users\Henri\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
[27] broadcasted at .\broadcast.jl:1263 [inlined]
[28] L_clip at c:\Users\Henri\OneDrive - UCL\Doctorat\DRL_SIP\src\PPO.jl:57 [inlined]
[29] _pullback(::Zygote.Context, ::typeof(L_clip), ::CuArray{Float32,2}, ::CuArray{Float32,2}, ::Base.ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ::Gaussian_Actor{Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(softplus),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}}}}}, ::Gaussian_Actor{Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(softplus),CuArray{Float32,2},CuArray{Float32,1}}}},Chain{Tuple{Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}},Dense{typeof(relu),CuArray{Float32,2},CuArray{Float32,1}}}}}, ::Float64) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
[30] #68 at .\REPL[112]:1 [inlined]
[31] _pullback(::Zygote.Context, ::var"#68#69") at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
[32] pullback(::Function, ::Zygote.Params) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:172
[33] gradient(::Function, ::Zygote.Params) at C:\Users\Henri\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:53
[34] top-level scope at REPL[112]:1
[35] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
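One thing I notice in the stacktrace: the advantage argument arrives as a Base.ReshapedArray wrapping an Adjoint of a CuArray rather than a plain CuArray, and the failing broadcast of g runs over exactly that wrapper. This is the kind of reduced case I have been trying to build from those types (a guess, not a confirmed reproducer):

A = reshape(cu(rand(Float32, 4, 2))', 1, 8)  # ReshapedArray{Float32,2,Adjoint{Float32,CuArray{Float32,2}},…}
Flux.gradient(x -> sum(g(x, 0.2f0)), A)      # does broadcasting g over the wrapper hit scalar getindex too?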
Here are a few things that make the train loop work (each of them works on its own):
- replacing prod with sum(..., dims = 1) in pdf: my train loop then works, BUT NOT the Flux.gradient() call above
- removing , g(advantage, clip) from L_clip
- removing ratio .* advantage, from L_clip
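Given the ReshapedArray/Adjoint wrapper noted above, one workaround I plan to try is materializing the advantages into a plain CuArray before they reach the loss; a sketch, where raw_advantages stands for wherever my advantages actually come from:

advantages = reshape(permutedims(raw_advantages), 1, :)  # permutedims copies, unlike ', so this yields a plain CuArray
Flux.gradient(() -> L_clip(states |> gpu, actions |> gpu, advantages, actor, actor_old, clip), Flux.params(actor))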