Flux gradients interpretation

Here is the code for a simple RNN:

using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(Float32,2) for i = 1:3]
y = [ones(Float32, 3) for i=1:3]

function loss(x, y)
    # Stack the per-step errors into one vector, then sum their squares.
    a = flat_this(rnn.(x) .- y)
    return sum(abs2, a)
end

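function flat_this(x)
    # ones(0) is an empty Vector{Float64}, so vcat promotes the Float32
    # inputs to Float64; this is why the gradients print as Float64.
    arr = ones(0)
    for i = 1:length(x)
        arr = vcat(arr, x[i])
    end
    return arr
end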

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params(rnn)) do
    sum(loss(seq,y))
end

for p in grads_seq.grads
    println("$p\n")
end

The output is the following:

Pair{Any,Any}(Float32[0.0; 0.0; 0.0], [3.7878030195245587; 1.612768501977888; -2.335516551451906])

Pair{Any,Any}(:(Main.rnn), Base.RefValue{Any}((cell = nothing, state = nothing)))

Pair{Any,Any}(:(Main.y), [[1.8429752588272095, 0.6448161602020264, 2.7944788932800293], [1.6635382175445557, 1.508173942565918, 3.3710970878601074], [3.514638900756836, 1.2005153894424438, 3.7064309120178223]])

Pair{Any,Any}(Float32[0.9722849 -0.80289185; 0.85602915 0.34971967; -0.20451058 -0.39597183], [0.9432474362031119 -0.07124584719900495; -0.6602998565304186 -1.1630237278179147; -4.478283682945728 -4.956677223884434])

Pair{Any,Any}(Float32[0.0, 0.0, 0.0], [0.4786334111414894, -1.9622023618757174, -8.015419969832937])

Pair{Any,Any}(:(Main.seq), [[2.9819661075480615, 0.3152601911834616], [-0.44246049922761166, 0.5381940409583041], [-2.1146017665440304, 1.249914684580733]])

Pair{Any,Any}(Recur(RNNCell(2, 3, tanh)), Base.RefValue{Any}((cell = nothing, state = nothing)))

Pair{Any,Any}(Float32[-0.07491565 0.4195478 0.795933; -0.6760104 -0.46338344 -0.15896535; -0.9976423 -0.2424345 0.89184976], [-0.2417448859956066 -0.27872805151029967 0.9748700622226689; -0.2748194556314892 -1.1553449387338601 1.2234039848773337; -0.3902901267103135 -2.15243746698388 1.8077103083289559])

How can I tell which weight each gradient belongs to?

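You can index the Grads struct returned from gradient with the parameter arrays in params(rnn) (i.e. your model’s parameters) to get the gradient for a specific parameter, e.g. grads_seq[rnn.cell.Wi]. In the pairs you printed, the first element is the key (the parameter array itself, or a global such as Main.seq) and the second element is its gradient, so you can also match them by shape: the 3×2 matrix is Wi and the 3×3 matrix is Wh.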
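To print every gradient next to the name of the field it belongs to, something like the following might work. This is only a sketch: it assumes a Flux version that still supports implicit parameters (Flux.params) and the Recur wrapper with a cell field, as in the question. The names ps, grads and named_grads are made up for this example, and the loss is the same squared error written without the flat_this helper.

```julia
using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(Float32, 2) for _ in 1:3]
y = [ones(Float32, 3) for _ in 1:3]

Flux.reset!(rnn)
ps = Flux.params(rnn)
grads = Flux.gradient(ps) do
    # Same squared-error loss as in the question, without the helper function.
    sum(abs2, reduce(vcat, [rnn(x) for x in seq] .- y))
end

# Pair each array-valued field of the cell with its gradient.
# Field names depend on the Flux version (assumed here: Wi, Wh, b, state0),
# so they are read from the type instead of being hard-coded.
named_grads = Dict{Symbol,Any}()
for name in fieldnames(typeof(rnn.cell))
    p = getfield(rnn.cell, name)
    p isa AbstractArray || continue      # skip non-array fields such as the activation
    any(q -> q === p, ps) || continue    # skip arrays that are not tracked parameters
    named_grads[name] = grads[p]         # Grads is indexed by the parameter array itself
    println(name, " ", size(p), " => ", grads[p])
end
```

Under those assumptions, named_grads[:Wi] holds the same array as grads[rnn.cell.Wi]. The lookup is by object identity, which is why you index with the parameter array and not with a name.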
