I am currently working on a reinforcement learning problem that requires me to form and store matrices of gradients and then update the parameters of the neural network manually.
Basically, I take the gradient of a loss function with respect to the parameters of a neural network and reshape it into a column vector (which my method requires). I repeat this a few times and end up with a vector w that I want to apply as an update, i.e. my update rule is
`θ = θ + α*w`
where θ denotes the parameters of the network. Because of how Flux works, the parameters are stored in a special structure (`Params`), so I do not have direct access to them as a single vector.
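The update rule does not actually need θ as one vector: it can be applied array by array, as long as w is laid out in the same order as the parameter arrays and each array is flattened column-major. A minimal sketch, assuming a plain ordered `Vector` of arrays as a stand-in for the network parameters (`params_list` and `apply_update!` are hypothetical names, not Flux API):

```julia
# Hypothetical stand-ins for the network's parameter arrays, in a fixed order.
W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
params_list = [W, b]

# Apply θ = θ + α*w in place. `w` must list the entries of params_list[1]
# (column-major), then params_list[2], and so on.
function apply_update!(params_list, w::AbstractVector, α::Real)
    @assert length(w) == sum(length, params_list)
    s = 1
    for p in params_list
        l = length(p)
        # Take the next l entries of w, reshape them to p's shape, add in place.
        p .+= α .* reshape(view(w, s:s+l-1), size(p))
        s += l
    end
    return params_list
end

w = collect(1.0:6.0)
apply_update!(params_list, w, 0.1)   # mutates W and b
```

Because `p .+= ...` writes into the existing arrays, anything holding a reference to them sees the update. The same loop should work with `ps = Flux.params(model)` in place of `params_list`, since iterating `Params` yields the parameter arrays, but I have not checked this against every Flux version.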
This is how I obtained the gradients as a vector:
```julia
function gradcollector(gs)
    gradients = []
    # Each entry of the IdDict gs.grads is a (parameter => gradient) pair
    for (p, g) in gs.grads
        push!(gradients, g)
    end
    return collect(Iterators.flatten(gradients))
end
```
where `gs` is obtained with something like:

```julia
ps = Flux.params(model)
gs = Flux.gradient(() -> loss(input), ps)
```
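Flattening works because Julia stores arrays column-major: `vec` and `Iterators.flatten` both read entries in that order, and `reshape` with the original `size` undoes it. A small illustration using only Base (the arrays are made up):

```julia
A = [1 3; 2 4]              # 2×2 matrix, stored column by column
v = vec(A)                  # [1, 2, 3, 4]
A2 = reshape(v, size(A))    # same shape and entries as A again

B = [10, 20]
# Iterators.flatten walks each array in column-major order, one after another.
flat = collect(Iterators.flatten([A, B]))   # [1, 2, 3, 4, 10, 20]
```

So each slice of the flat vector can be restored with `reshape`, provided you know which array it came from and in what order the arrays were concatenated.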
I tried to reshape w manually into the shapes of the parameters, by iterating through the gradient shapes and reshaping each corresponding slice into a matrix. That fails, however, because the order in which the parameter matrices are stored differs from the order in which `gs.grads` returns the gradients.
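The mismatch comes from `gs.grads` being an `IdDict` keyed by the parameter arrays themselves, so iterating it follows hash order rather than the order of `ps`. One way around it is to walk the ordered parameters and look each gradient up by identity (in Zygote that lookup is `gs[p]`). A Base-only sketch, where `ps` and `grads` are hypothetical stand-ins for `Flux.params(model)` and `gs.grads`:

```julia
W = rand(3, 2)
b = rand(3)
ps = [W, b]   # ordered, like Flux.params(model)

# Gradient table keyed by the arrays themselves, as Zygote does internally.
grads = IdDict{Any,Any}(W => ones(3, 2), b => fill(2.0, 3))

# Iterating `grads` directly would give hash order; walking `ps` and looking
# up each array by identity gives a flat vector in parameter order.
flat = reduce(vcat, [vec(grads[p]) for p in ps])
```

Unflattening then uses the same walk over `ps`, so both directions agree on the layout.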
Any help on how to fix this problem would be greatly appreciated.
I might attempt an implementation of packing/unpacking for the `Zygote.Grads` type, which might be welcome in the Flux repository.
If you have made any progress on this problem, I would be interested in learning from your experience. I will post here once I have a working solution.
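```julia
using Flux, Zygote, Test

# Total number of entries across all tracked parameters.
gradlength(ps::Zygote.Params) = sum(length(p) for p in ps)

# Copy each gradient into `gradvec`. Iterating over `ps` (not the IdDict
# `grads.grads`) keeps the same order as `Flux.params(model)`.
function flatten!(gradvec, grads::Zygote.Grads, ps::Zygote.Params)
    @assert length(gradvec) == gradlength(ps)
    s = 1
    for p in ps
        l = length(p)
        g = grads[p]
        # A parameter that does not influence the loss has gradient `nothing`.
        gradvec[s:s+l-1] .= g === nothing ? 0 : vec(g)
        s += l
    end
    gradvec
end
flatten(grads::Zygote.Grads, ps::Zygote.Params) = flatten!(zeros(gradlength(ps)), grads, ps)

# Inverse of `flatten!`: write consecutive slices of `gradvec` back into `grads` in place.
function unflatten!(gradvec, grads::Zygote.Grads, ps::Zygote.Params)
    s = 1
    for p in ps
        l = length(p)
        g = grads[p]
        g === nothing || (g .= reshape(gradvec[s:s+l-1], size(p)))
        s += l
    end
    grads
end

# `unflatten!` mutates `gs`, so compare flat vectors rather than `gs` with itself.
v = flatten(gs, ps)
@test flatten(unflatten!(2 .* v, gs, ps), ps) == 2 .* v
```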
This seems like the kind of problem Flatten.jl is designed to solve. However, it won't work for arrays: it relies entirely on compile-time code generation, and inside a `@generated` function the length of most arrays is not known.
It should be able to generate very fast code for static arrays if you add the appropriate `constructor_of()` method.
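To illustrate the compile-time limitation: inside a `@generated` function the generator only sees argument *types*, so it can unroll over a tuple (whose length `N` is part of the type) but not over a `Vector`, whose type carries no length. This is a generic sketch of that idea, not Flatten.jl's API; `unrolled_sum` is a hypothetical name:

```julia
# The generator runs at compile time with `x` bound to the argument's type,
# so N is available here and the loop below is fully unrolled.
@generated function unrolled_sum(x::NTuple{N,Float64}) where {N}
    N == 0 && return :(0.0)
    ex = :(x[1])
    for i in 2:N
        ex = :($ex + x[$i])   # builds x[1] + x[2] + ... + x[N]
    end
    return ex
end

s = unrolled_sum((1.0, 2.0, 3.0))
```

For a `Vector{Float64}` argument the generator would only see `Vector{Float64}`, with no length to unroll over, which is why this approach suits tuples and static arrays but not ordinary arrays.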
There are some updates on master, but I'm not sure they have made it into a registered release yet; I've been preoccupied.
(Nested is not really a package anymore; it was refactored down so small that I just made it part of Flatten.)