Manually updating the parameters of a Neural Network in Flux

I am currently working on a reinforcement learning problem that requires forming and storing matrices of gradients and then updating the parameters of the neural network manually.
Basically, I take the gradient of a loss function with respect to the parameters of a neural network and then reshape this gradient into a column vector (which I need for my purposes). I repeat this a few times and end up with a vector w that I want to use for the update, i.e. my update rule is

θ = θ + α*w
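When θ is spread over several arrays, the update can still be applied with one flat w, as long as w was built by concatenating `vec` of each array's gradient in a known order. Below is a minimal sketch under that assumption. `update_params!` is a hypothetical helper, and a plain vector of arrays stands in for Flux's parameter collection; a `Flux.params(model)` object, which iterates its arrays in order, should work in its place.

```julia
# Apply θ ← θ + α*w in place. `params` is an ordered collection of arrays;
# `w` is one flat vector holding vec(Δ) for each array, concatenated in that order.
function update_params!(params, w::AbstractVector, α::Real)
    total = sum(length, params)
    length(w) == total ||
        throw(DimensionMismatch("w has length $(length(w)), expected $total"))
    offset = 0
    for p in params
        n = length(p)
        # take this array's slice of w (column-major, like vec) and restore its shape
        p .+= α .* reshape(view(w, offset+1:offset+n), size(p))
        offset += n
    end
    return params
end

W = ones(2, 2)
b = zeros(2)
update_params!([W, b], collect(1.0:6.0), 0.5)
# W == [1.5 2.5; 2.0 3.0], b == [2.5, 3.0]
```

The arrays are mutated in place, so the model that owns them sees the new values without being rebuilt.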

where θ denotes the parameters of the network. Because of the way Flux works, the parameters are stored in a special structure, so I do not have direct access to them as a single vector.

This is how I obtained the gradients as a vector:

function gradcollector(gs)
    gradients = []
    for (p, g) in gs.grads  # gs.grads is an IdDict: parameter => gradient
        push!(gradients, g)
    end
    return collect(Iterators.flatten(gradients))
end

where gs is obtained by doing something like

ps = Flux.params(model)

gs = Flux.gradient(() -> loss(input), ps)

I tried to manually reshape w into the shapes of the parameters, by iterating through the shapes of the gradients and reshaping each respective slice into a matrix, but even that fails: the order in which the parameter matrices are stored differs from the order I get from the gradient.
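The mismatch is consistent with how Zygote stores gradients: `gs.grads` is an `IdDict` keyed by the parameter arrays themselves, and its iteration order is not tied to the order of the layers. A sketch of a fix is to walk the ordered parameter list and look each gradient up (with real Flux objects, `for p in ps` and `gs[p]`). The example below mimics that setup with Base Julia only; the arrays, the dummy gradients, and `flatten_in_order` are hypothetical stand-ins.

```julia
# Stand-ins for the parameter arrays, listed in model order (as Flux.params would list them).
W1, b1 = ones(2, 3), zeros(2)
W2, b2 = ones(1, 2), zeros(1)
ps = [W1, b1, W2, b2]

# Stand-in for gs.grads: an IdDict keyed by the parameter arrays themselves.
# Its iteration order is not tied to insertion order.
grads = IdDict{Any,Any}()
for (i, p) in enumerate(ps)
    grads[p] = fill(Float64(i), size(p))   # dummy gradient with the same shape as p
end

# Flatten by walking the ordered parameter list, not the dictionary.
flatten_in_order(ps, grads) = reduce(vcat, [vec(grads[p]) for p in ps])

gradvec = flatten_in_order(ps, grads)
```

Because the same `ps` order is used to flatten and later to reshape or update, the slices of the flat vector always line up with the right arrays.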

Any help on how to fix this problem would be greatly appreciated.


I am also interested in this, and I believe many others would be as well. I list some packages that target this particular problem below for reference

I might attempt an implementation of packing/unpacking for the Zygote.Grads type, which might be well received in the Flux repository.

If you have made any progress on this problem, I would be interested in learning from your experience. I will post here once I have a working solution.

Here’s my attempt at flattening and unflattening the Zygote.Grads type:

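using Zygote, Test

# Total number of gradient entries, counted over the tracked parameters
gradlength(grads::Zygote.Grads) = sum(length(p) for p in grads.params)

# Copy every gradient into gradvec. grads.grads is an IdDict with no
# guaranteed order, so walk grads.params to get a reproducible layout.
function flatten!(gradvec, grads::Zygote.Grads)
    @assert length(gradvec) == gradlength(grads)
    s = 1
    for p in grads.params
        g = grads[p]
        l = length(g)
        gradvec[s:s+l-1] .= vec(g)
        s += l
    end
    return gradvec
end

flatten(grads::Zygote.Grads) = flatten!(zeros(gradlength(grads)), grads)

# Write gradvec back into the gradient arrays, in the same order as flatten!
function unflatten!(gradvec, grads::Zygote.Grads)
    @assert length(gradvec) == gradlength(grads)
    s = 1
    for p in grads.params
        g = grads[p]
        l = length(g)
        g .= reshape(gradvec[s:s+l-1], size(g))
        s += l
    end
    return grads
end

# Small example so the test is self-contained
W, b = rand(3, 2), rand(3)
grads = gradient(() -> sum(abs2, W * ones(2) .+ b), Zygote.Params([W, b]))

v = flatten(grads)
@test flatten(unflatten!(v, grads)) == v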

This seems like the kind of problem Flatten.jl is designed to solve. However, it won’t work for arrays: it relies entirely on compile-time code generation, and inside @generated the length of most arrays is not known.

It should be able to generate very fast code for static arrays if you add the appropriate constructor_of() method.
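The compile-time limitation can be illustrated with a tiny @generated function: its body sees only the argument's type. A `Tuple` (which is what a static array such as StaticArrays' `SVector` stores its data in) carries its length in its type, while a `Vector{Float64}` does not. `length_if_static` is a hypothetical illustration, not part of Flatten.jl.

```julia
# In the body of a @generated function, `x` is bound to the argument's *type*;
# the returned expression becomes the compiled method body.
@generated function length_if_static(x)
    if x <: Tuple
        n = fieldcount(x)   # a Tuple's length is part of its type
        return :($n)        # constant baked into the compiled code
    else
        return :(nothing)   # the type Vector{Float64} says nothing about length
    end
end

length_if_static((1.0, 2.0, 3.0))   # 3, known at compile time
length_if_static([1.0, 2.0, 3.0])   # nothing
```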

There are some updates on master, but I’m not sure whether they have made it into a registered release yet; I’ve been preoccupied.

(Nested is not really a package anymore; it got refactored down so small that I just made it part of Flatten.)