BoundsError in Zygote RNN gradient

Hi everyone,

I’ve experienced some problems with Zygote regarding the computation of gradients of a loss function obtained via sampling. An RNN generates samples autoregressively and the loss function is then computed using the samples. Unfortunately, trying to compute the gradient of the function with respect to the model parameters leads to a ‘‘BoundsError: attempt to access 0-element Vector{Any} at index ’’. Does anybody have any idea about what could be happening here?

Thanks in advance for any help!

using Flux
using StatsBase

d_h = 4
m = Chain(GRUv3(1 => d_h, init=Flux.truncated_normal(mean = 0, std = 1, lo = -2, hi = 2)), Dense(d_h => 2))
length = 4
N_s = 10000

function autoreg_sample(model, length)
    x_t = 0
    sample = []
    for t in 1:length
        output = model([x_t])
        probs = softmax(output)
        x_t = StatsBase.sample([0, 1], Weights([probs[1], probs[2]]))
        sample = [sample; x_t]
    end 
    Flux.reset!(model)
    return sample
end 

function f(x)
    return sum(x)
end 

function loss(model, length, N_s)
    result = Float64[]
    for i = 1:N_s
        s1 = autoreg_sample(model, length)
        s2 = autoreg_sample(model, length)
        result = [result; f(s1)*f(s2)]
    end 
    return mean(result)
end 

grads = Flux.gradient(model -> loss(model, length, N_s), m)
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index []
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{})
    @ Base ./abstractarray.jl:703
  [2] checkbounds
    @ ./abstractarray.jl:668 [inlined]
  [3] _getindex
    @ ./abstractarray.jl:1273 [inlined]
  [4] getindex
    @ ./abstractarray.jl:1241 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:314 [inlined]
  [6] (::Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:308
  [7] (::Zygote.var"#2193#back#313"{Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [8] Pullback
    @ ~/.julia/packages/StatsBase/XgjIN/src/weights.jl:13 [inlined]
  [9] (::Zygote.Pullback{Tuple{Type{Weights{Float32, Float32, Vector{Float32}}}, Vector{Float32}, Float32}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#330"}, Zygote.var"#2193#back#313"{Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float32}}, Vector{Float32}}, Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/StatsBase/XgjIN/src/weights.jl:13 [inlined]
 [11] Pullback
    @ ~/.julia/packages/StatsBase/XgjIN/src/weights.jl:16 [inlined]
 [12] (::Zygote.Pullback{Tuple{Type{Weights}, Vector{Float32}}, Tuple{Zygote.var"#2971#back#768"{Zygote.var"#762#766"{Vector{Float32}}}, Zygote.Pullback{Tuple{Type{Weights}, Vector{Float32}, Float32}, Tuple{Zygote.Pullback{Tuple{Type{Weights{Float32, Float32, Vector{Float32}}}, Vector{Float32}, Float32}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#330"}, Zygote.var"#2193#back#313"{Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float32}}, Vector{Float32}}, Any}}}}}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/Desktop/RNNWavefunctions/other_min.jl:15 [inlined]
 [14] (::Zygote.Pullback{Tuple{typeof(autoreg_sample), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/Desktop/RNNWavefunctions/other_min.jl:30 [inlined]
 [16] (::Zygote.Pullback{Tuple{typeof(loss), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64, Int64}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/Desktop/RNNWavefunctions/other_min.jl:36 [inlined]
 [18] (::Zygote.Pullback{Tuple{var"#7#8", Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{Zygote.Pullback{Tuple{typeof(loss), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64, Int64}, Any}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#7#8", Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{Zygote.Pullback{Tuple{typeof(loss), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64, Int64}, Any}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [20] gradient(f::Function, args::Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [21] top-level scope
    @ ~/Desktop/RNNWavefunctions/other_min.jl:36