How to obtain gradients when training a model

Cheers.

When training a model, I want to investigate the occurrence of exploding gradients. For this reason, I am using the following code to collect the gradient values while training one epoch:

using Flux

losses = Float32[]
allgrads = Any[]

for data in trainset
    X, y = data

    # compute the loss and the gradients w.r.t. the model in one pass
    val, grads = Flux.withgradient(model) do m
        result = m(X)
        lossFunction(result, y)
    end

    # record the loss and the full gradient tree for this minibatch
    push!(losses, val)
    push!(allgrads, grads)

    Flux.update!(optimizerState, model, grads[1])
end

The vector “allgrads” stores the gradients from all minibatches. For instance, one of its elements looks like:

((layers = ((layers = ((layers = ((layers = ((layers = ((σ = nothing, weight = [0.0006835972 6.486452f-5 -0.0006134349; 0.000548505 0.000153064 -0.000340064; 0.0003190739 5.5526354f-5 -0.00022200258;;; 0.00085182954 0.00029911275 -0.00045436312; 0.0006824782 0.00027364268 -0.00028414163; 0.00041472033 0.0002219973 -0.00023666112;;; 0.00092873315 0.00046966688 -0.00031694077; 0.0007362623 0.0004371706 -0.00015576907; 0.0005101147 0.00034959422 0.00023152842; 0.00019810414 0.0003540415 0.00025109033], bias = Float32[-6.82121f-13, -9.822543f-11, -5.820766f-11, -1.1641532f-10, -1.8189894f-12, 2.0372681f-10, 3.637979f-12, 5.0931703f-11], stride = nothing, pad = nothing, dilation = nothing, groups = nothing), (λ = nothing, β = Float32[4.887536f-5, -0.006509954, -0.0034560864, 0.0026293392, 0.00068988354, 0.0006955422, 0.00079482584, 0.0012673868], γ = Float32[-3.69288f-5, -0.00818499, -0.0013030381, 0.004327028, 0.0013199829, 0.0031451036, 0.00021715634, -0.0011249712], μ = nothing, σ² = nothing, ϵ = nothing, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing)),)),), (σ = nothing, weight = [0.0015703159;;; 0.0094113285;;; -0.0029107195;;; 0.006112651;;; 0.0075858478;;; 0.0015351735;;; -0.0013515109;;; 0.0019647335;;;;], bias = Float32[-0.007379764], stride = nothing, pad = nothing, dilation = nothing, groups = nothing)),)),), nothing),),)

As the structure looks a bit confusing, I wonder if someone could indicate a simple way to interpret it and extract the biases and weights.

Of course, any other method to achieve this goal is also welcome.

Thanks in advance.

model and grads[1] are trees with the same nesting structure and the same field names, except that model uses custom structs like Dense, while grads uses anonymous ones (NamedTuples).

Using a smaller example, here is how you can explore the two:

julia> model = Chain(Dense(2=>1), SkipConnection(Dense(1=>1),+))
Chain(
  Dense(2 => 1),                        # 3 parameters
  SkipConnection(
    Dense(1 => 1),                      # 2 parameters
    +,
  ),
)                   # Total: 4 arrays, 5 parameters, 292 bytes.

julia> grads = gradient(m -> sum(abs2, m([1,-1])), model)
((layers = ((weight = Float32[-2.6218274 2.6218274], bias = Float32[-2.6218274], σ = nothing), (layers = (weight = Float32[0.8607899;;], bias = Float32[-1.6526356], σ = nothing), connection = nothing)),),)

julia> model.layers[1]
Dense(2 => 1)       # 3 parameters

julia> model.layers[1].weight  # pressing tab will show you field names as you type
1×2 Matrix{Float32}:
 -0.675822  -0.154963

julia> model.layers[1].bias  # initialised to zero
1-element Vector{Float32}:
 0.0

julia> grads[1].layers[1]  # corresponding to Dense
(weight = Float32[-2.6218274 2.6218274], bias = Float32[-2.6218274], σ = nothing)

julia> grads[1].layers[1].weight
1×2 Matrix{Float32}:
 -2.62183  2.62183

julia> grads[1].layers[1].bias
1-element Vector{Float32}:
 -2.6218274

One catch is that model[2] also works, the same as model.layers[2], but won’t work on the gradient: grads[1][2] is an error. (Indexing a Chain indexes the tuple inside, but won’t work this way on a NamedTuple.)
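A quick illustration of that, continuing the session above (output abbreviated):

julia> model[2] === model.layers[2]  # Chain forwards integer indexing to its tuple
true

julia> grads[1].layers[2].layers  # the Dense inside the SkipConnection
(weight = Float32[0.8607899;;], bias = Float32[-1.6526356], σ = nothing)

julia> grads[1][2]  # the NamedTuple has only the one field `layers`
ERROR: BoundsError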

(They aren’t always strictly trees, since the same object can appear twice, but usually they are.)
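To come back to the original goal of spotting exploding gradients: one simple option is to walk each saved gradient tree, collect every array leaf, and summarise it. A minimal sketch in plain Julia, where gradleaves and maxabsgrad are made-up helper names, not Flux API:

# an array is a leaf; keep it
gradleaves(g::AbstractArray) = Any[g]

# recurse into Tuples and NamedTuples, concatenating the leaves found below
gradleaves(g::Union{Tuple,NamedTuple}) = reduce(vcat, [gradleaves(x) for x in g]; init = Any[])

# skip `nothing` and any other non-array entries
gradleaves(g) = Any[]

# largest absolute gradient entry in one minibatch's gradient tree
maxabsgrad(grads) = maximum(maximum(abs, g) for g in gradleaves(grads) if !isempty(g))

Then maxabsgrad.(allgrads) gives one number per minibatch, which you can plot or check against a threshold to flag exploding gradients.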