Hi everyone,
I’ve experienced some problems with Zygote regarding the computation of gradients of a loss function obtained via sampling. An RNN generates samples autoregressively and the loss function is then computed using the samples. Unfortunately, trying to compute the gradient of the function with respect to the model parameters leads to a ‘‘BoundsError: attempt to access 0-element Vector{Any} at index ’’. Does anybody have any idea about what could be happening here?
Thanks in advance for any help!
using Flux
using StatsBase
d_h = 4
m = Chain(GRUv3(1 => d_h, init=Flux.truncated_normal(mean = 0, std = 1, lo = -2, hi = 2)), Dense(d_h => 2))
length = 4
N_s = 10000
function autoreg_sample(model, length)
x_t = 0
sample = []
for t in 1:length
output = model([x_t])
probs = softmax(output)
x_t = StatsBase.sample([0, 1], Weights([probs[1], probs[2]]))
sample = [sample; x_t]
end
Flux.reset!(model)
return sample
end
function f(x)
return sum(x)
end
function loss(model, length, N_s)
result = Float64[]
for i = 1:N_s
s1 = autoreg_sample(model, length)
s2 = autoreg_sample(model, length)
result = [result; f(s1)*f(s2)]
end
return mean(result)
end
grads = Flux.gradient(model -> loss(model, length, N_s), m)
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index []
Stacktrace:
[1] throw_boundserror(A::Vector{Any}, I::Tuple{})
@ Base ./abstractarray.jl:703
[2] checkbounds
@ ./abstractarray.jl:668 [inlined]
[3] _getindex
@ ./abstractarray.jl:1273 [inlined]
[4] getindex
@ ./abstractarray.jl:1241 [inlined]
[5] macro expansion
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:314 [inlined]
[6] (::Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:308
[7] (::Zygote.var"#2193#back#313"{Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[8] Pullback
@ ~/.julia/packages/StatsBase/XgjIN/src/weights.jl:13 [inlined]
[9] (::Zygote.Pullback{Tuple{Type{Weights{Float32, Float32, Vector{Float32}}}, Vector{Float32}, Float32}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#330"}, Zygote.var"#2193#back#313"{Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float32}}, Vector{Float32}}, Any}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/packages/StatsBase/XgjIN/src/weights.jl:13 [inlined]
[11] Pullback
@ ~/.julia/packages/StatsBase/XgjIN/src/weights.jl:16 [inlined]
[12] (::Zygote.Pullback{Tuple{Type{Weights}, Vector{Float32}}, Tuple{Zygote.var"#2971#back#768"{Zygote.var"#762#766"{Vector{Float32}}}, Zygote.Pullback{Tuple{Type{Weights}, Vector{Float32}, Float32}, Tuple{Zygote.Pullback{Tuple{Type{Weights{Float32, Float32, Vector{Float32}}}, Vector{Float32}, Float32}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#330"}, Zygote.var"#2193#back#313"{Zygote.Jnew{Weights{Float32, Float32, Vector{Float32}}, Vector{Any}, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float32}}, Vector{Float32}}, Any}}}}}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[13] Pullback
@ ~/Desktop/RNNWavefunctions/other_min.jl:15 [inlined]
[14] (::Zygote.Pullback{Tuple{typeof(autoreg_sample), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[15] Pullback
@ ~/Desktop/RNNWavefunctions/other_min.jl:30 [inlined]
[16] (::Zygote.Pullback{Tuple{typeof(loss), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64, Int64}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[17] Pullback
@ ~/Desktop/RNNWavefunctions/other_min.jl:36 [inlined]
[18] (::Zygote.Pullback{Tuple{var"#7#8", Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{Zygote.Pullback{Tuple{typeof(loss), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64, Int64}, Any}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#7#8", Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{Zygote.Pullback{Tuple{typeof(loss), Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Int64, Int64}, Any}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}, Zygote.var"#1972#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Int64}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
[20] gradient(f::Function, args::Chain{Tuple{Flux.Recur{Flux.GRUv3Cell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
[21] top-level scope
@ ~/Desktop/RNNWavefunctions/other_min.jl:36