Here is the code for a simple RNN network:
using Flux
# Flux recurrent layer built with sizes (2, 3); it is printed as
# `Recur(RNNCell(2, 3, tanh))`.
rnn = Flux.RNN(2, 3)
# Input sequence: three random Float32 vectors of length 2.
seq = map(_ -> rand(Float32, 2), 1:3)
# Targets: three all-ones Float32 vectors of length 3.
y = map(_ -> ones(Float32, 3), 1:3)
"""
    loss(x, y)

Return the sum of squared differences between the outputs of the global `rnn`
broadcast over the elements of `x` and the corresponding elements of `y`.

The per-step differences are flattened into a single vector with `flat_this`
before the squares are summed.
"""
function loss(x, y)
    residuals = rnn.(x) .- y
    return sum(abs2, flat_this(residuals))
end
"""
    flat_this(x)

Concatenate the elements `x[1], x[2], …, x[length(x)]` end to end into a
single vector and return it.

The accumulator starts out as an empty `Float64` vector, so `vcat` promotes
`Float32` or integer elements to `Float64` in the result. An empty `x` yields
an empty `Float64` vector.
"""
function flat_this(x)
    flat = zeros(0)  # empty Float64 seed; determines the promoted element type
    for k in Base.OneTo(length(x))
        flat = vcat(flat, x[k])
    end
    return flat
end
# Reset the recurrent layer before evaluating the sequence.
Flux.reset!(rnn);
# Gradient of the loss with respect to the parameters collected from `rnn`.
grads_seq = Flux.gradient(() -> sum(loss(seq, y)), Flux.params(rnn))
# Print every entry of the gradient container; each entry is shown as a
# `key => gradient` Pair.
foreach(p -> println("$p\n"), grads_seq.grads)
The output is the following:
Pair{Any,Any}(Float32[0.0; 0.0; 0.0], [3.7878030195245587; 1.612768501977888; -2.335516551451906])
Pair{Any,Any}(:(Main.rnn), Base.RefValue{Any}((cell = nothing, state = nothing)))
Pair{Any,Any}(:(Main.y), [[1.8429752588272095, 0.6448161602020264, 2.7944788932800293], [1.6635382175445557, 1.508173942565918, 3.3710970878601074], [3.514638900756836, 1.2005153894424438, 3.7064309120178223]])
Pair{Any,Any}(Float32[0.9722849 -0.80289185; 0.85602915 0.34971967; -0.20451058 -0.39597183], [0.9432474362031119 -0.07124584719900495; -0.6602998565304186 -1.1630237278179147; -4.478283682945728 -4.956677223884434])
Pair{Any,Any}(Float32[0.0, 0.0, 0.0], [0.4786334111414894, -1.9622023618757174, -8.015419969832937])
Pair{Any,Any}(:(Main.seq), [[2.9819661075480615, 0.3152601911834616], [-0.44246049922761166, 0.5381940409583041], [-2.1146017665440304, 1.249914684580733]])
Pair{Any,Any}(Recur(RNNCell(2, 3, tanh)), Base.RefValue{Any}((cell = nothing, state = nothing)))
Pair{Any,Any}(Float32[-0.07491565 0.4195478 0.795933; -0.6760104 -0.46338344 -0.15896535; -0.9976423 -0.2424345 0.89184976], [-0.2417448859956066 -0.27872805151029967 0.9748700622226689; -0.2748194556314892 -1.1553449387338601 1.2234039848773337; -0.3902901267103135 -2.15243746698388 1.8077103083289559])
How can I tell which weight each gradient corresponds to?