 # Flux.params of a matrix implemented as a struct

Dear All,

I have a student who is implementing a few algorithms for distributed optimization (training of neural networks) using the Flux / Zygote tooling. We have been tracking down various sources of inefficiency, and I have found that manipulating (adding) gradients stored in `Zygote.Grads` (essentially an IdDict) adds quite some overhead. Therefore I had the idea to pre-allocate one large vector for all gradients, so that they are stored in a single contiguous array, and to initialize `Grads` with views into this array (see the MWE below).

```julia
using Flux, Zygote, BenchmarkTools
using Zygote: Context, _pullback, cache, sensitivity, Grads

function initgrad(ps)
    l = mapreduce(length, +, ps)
    vec_ps = zeros(Float32, l)
    start = 1
    gs_cache = IdDict{Any,Any}()
    for p in ps
        stop = start + length(p) - 1
        # each parameter gets a view into the flat vector, shaped like the parameter
        gs_cache[p] = reshape(view(vec_ps, start:stop), size(p))
        start = stop + 1
    end
    Grads(gs_cache, ps), vec_ps
end

function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
    cx = Context(gs.grads)
    for p in ps
        cache(cx)[p] .= 0   # zero the preallocated views in place
    end
    y, back = _pullback(cx, f)
    back(sensitivity(y))
    y, gs
end

function addgrad!(gs₁, gs₂)
    for p in gs₁.params
        if gs₁[p] !== nothing && gs₂[p] !== nothing
            gs₁[p] .+= gs₂[p]
        elseif gs₂[p] !== nothing
            gs₁.grads[p] = gs₂[p]
        end
    end
    gs₁
end

x = randn(Float32, 3, 10)
m = Chain(Dense(3, 4), Dense(4, 1))
m(x)                    # warm-up call
f = () -> sum(m(x))
ps = Flux.params(m)
gs = gradient(f, ps)
gsᵣ, gs_vec = initgrad(ps)
y, gsᵣ = gradient!(f, gsᵣ, ps)
all(gsᵣ[p] ≈ gs[p] for p in ps)
```

Having all parameters in a single array makes reducing the gradients much faster:

```julia
julia> gs₁, v₁ = initgrad(ps);

julia> gs₂, v₂ = initgrad(ps);

julia> gradient!(f, gs₁, ps);

julia> gradient!(f, gs₂, ps);

julia> @btime addgrad!(gs₁, gs₂);
  2.994 μs (22 allocations: 1.34 KiB)

julia> @btime v₁ .+= v₂;
  345.574 ns (2 allocations: 64 bytes)
```

That is, reducing the flat vectors is roughly an order of magnitude faster than walking the `IdDict`.
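For a sense of how this plays out in the distributed setting, here is a minimal sketch of a reduction over per-worker flat vectors; `average_gradients!` is a hypothetical helper (not part of Flux or Zygote), and each worker is assumed to have produced its own `(gs, vec)` pair from `initgrad`:

```julia
# Hypothetical sketch: averaging per-worker gradients via the flat vectors.
# Only the vectors are touched; the views stored inside each worker's Grads
# see the result of updating vecs[1] automatically.
function average_gradients!(vecs::Vector{Vector{Float32}})
    acc = vecs[1]
    for i in 2:length(vecs)
        acc .+= vecs[i]   # one fused broadcast per worker, no IdDict lookups
    end
    acc ./= length(vecs)
    acc
end
```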

The trouble is that this solution is not sufficiently general.
Let's say that I define a special type of matrix which is composed of some trainable arrays. I do this frequently, for example for matrices I keep in SVD form (see SumProductTransform.jl/svddense.jl at master · pevnak/SumProductTransform.jl · GitHub) or matrices handling missing values (see Mill.jl/post_imputing_matrix.jl at master · CTUAvastLab/Mill.jl · GitHub).

An MWE of such a matrix would be something like this:

```julia
struct IMatrix{T, W<:AbstractArray{T}, V<:AbstractArray{T}} <: AbstractMatrix{T}
    w::W
    v::V
end

Flux.@functor IMatrix

IMatrix(r::Int, c::Int) = IMatrix(randn(Float32, r, c), randn(Float32, r))

Base.size(m::IMatrix) = size(m.w)
Base.show(io::IO, m::IMatrix) = print(io, "IMatrix $(size(m))")

import Base: *
*(a::IMatrix, b::AbstractMatrix) = a.w * b
*(a::AbstractMatrix, b::IMatrix) = a * b.w
```

Now, if I invoke `Flux.params(model)`, it will contain the `IMatrix` itself, which breaks the above scheme, since the scheme assumes that every parameter is a plain array rather than a composite type.
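One way to see the mismatch concretely (a sketch using the `IMatrix` definition above): `initgrad` would reserve `length(p)` slots and key a single array-shaped view by the `IMatrix` itself, while the actual trainable storage is the pair of fields `w` and `v`:

```julia
A = IMatrix(4, 3)
length(A)                       # 12 — the slots initgrad would reserve for A
sum(length, Flux.trainable(A))  # 16 — the floats actually being trained (w and v)
```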

Therefore my question is: does anyone know how to fix the above approach of storing grads in a vector so that it also works for matrices implemented as composite types?

Thanks a lot for any hints.

Need `IMatrix` be a subtype of `AbstractMatrix`? If not, the `functor`ization should work properly. If so, you could overload `params!`.

Also RE manipulating grads, have you looked into using `Flux.destructure` to get a parameter vector?
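For reference, a minimal sketch of the `destructure` approach mentioned above (using the `Chain` `m` from the first MWE):

```julia
θ, re = Flux.destructure(m)  # θ: a flat Vector of all parameters, re: a rebuilder
m2 = re(θ .+ 0.1f0)          # reconstruct a model from a modified flat vector
```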

Thanks for the answer.

I know that if I do not subtype it as an `AbstractMatrix`, the problem is solved, but I would actually like it to be a subtype of `AbstractMatrix`, and that is the problem.
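(For completeness, a sketch of the variant that does work, with a hypothetical `FreeIMatrix` that is functored but not an `AbstractMatrix`, so `params` recurses into its fields:)

```julia
struct FreeIMatrix{W, V}  # hypothetical: same fields, no AbstractMatrix supertype
    w::W
    v::V
end
Flux.@functor FreeIMatrix

fm = FreeIMatrix(randn(Float32, 4, 3), randn(Float32, 4))
Flux.params(fm)  # collects fm.w and fm.v, because params! recurses into the struct
```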

I have not thought about using destructure / restructure for this, which is neat, but I guess my solution has smaller overhead, as I do not need to copy the gradients (on CPU, of course). Wouldn't destructure and restructure suffer from the same problem?

Check out ComponentArrays

I did. Can you please be more specific and point me to how it solves my problem?

> I know that if I do not subtype it as an `AbstractMatrix`, the problem is solved, but I would actually like it to be a subtype of `AbstractMatrix` and that is the problem.

As @ToucheSir suggested, you can overload `params!`. E.g.

```julia
using Flux: Params, trainable
using Zygote: IdSet

function Flux.params!(p::Params, x::IMatrix, seen = IdSet())
    x in seen && return
    push!(seen, x)
    for child in trainable(x)
        Flux.params!(p, child, seen)
    end
end
```

So this partially works. I have implemented this idea, but the gradients of the parameters are not stored properly in the `gs`. See:

```julia
mi = Chain(Dense(IMatrix(4, 3), randn(Float32, 4), identity),
           Dense(IMatrix(1, 4), randn(Float32, 1), identity))
ps = Flux.params(mi)
mi(x)
gs = gradient(() -> sum(mi(x)), ps)

julia> [gs[p] for p in ps]
4-element Vector{Array{Float32, N} where N}:
 [-0.052534547 0.31615198 0.34340757; 0.020920932 -0.1259018 -0.13675584; 1.0563443 -6.3570614 -6.9051056; -1.6292106 9.804562 10.649816]
 [0.5259397, -0.20944595, -10.575391, 16.310537]
 [3.0866313 5.0587535 24.052073 -4.2675037]
 [10.0]
```

but when I add

```julia
Flux.params!(p::Params, x::IMatrix, seen = IdSet()) = (push!(p, x.w); push!(p, x.v))
```
```julia
ps = Flux.params(mi)
mi(x)
gs = gradient(() -> sum(mi(x)), ps)

julia> [gs[p] for p in ps]
6-element Vector{Union{Nothing, Vector{Float32}}}:
 nothing
 nothing
 Float32[0.5259397, -0.20944595, -10.575391, 16.310537]
 nothing
 nothing
 Float32[10.0]
```

i.e. the gradients of the parameters of the `IMatrix`es are not written correctly to the `Zygote.Grads` structure. I guess it might have something to do with defining a custom gradient, which should be a breeze, but I do not know how to do it so that the gradients are correctly assigned. I have tried something like

```julia
Zygote.@adjoint function *(a::IMatrix, b::AbstractMatrix)
    # go through the fields directly; a' would require getindex on IMatrix
    a.w * b, Δ -> ((w = Δ * b', v = nothing), a.w' * Δ)
end
```

but it does not work.

Any help is appreciated.
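One direction that might be relevant here (an untested sketch; it assumes Zygote's internal `accum_param` helper and the `__context__` variable that `Zygote.@adjoint` makes available, so treat it as a hypothesis rather than a verified fix): with implicit parameters, gradients only land in `Grads` for arrays that Zygote saw being used, and a custom adjoint hides the access to `a.w`, so the adjoint would have to report the field gradient explicitly:

```julia
# Untested sketch: explicitly accumulate the field gradient into the implicit
# parameter context, since the custom adjoint bypasses the getfield tracking.
Zygote.@adjoint function *(a::IMatrix, b::AbstractMatrix)
    a.w * b, Δ -> begin
        Zygote.accum_param(__context__, a.w, Δ * b')  # key the grad by a.w itself
        ((w = Δ * b', v = nothing), a.w' * Δ)
    end
end
```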

I have improved my implementation to support `AbstractMatrices` implemented as a `struct`, even though I do not know how general the solution is (I can post it later). But I have encountered another behavior which I do not understand and cannot work out from the source of Zygote.

The following snippet defines a `gradient!` whose sole purpose is to accumulate the gradients of the parameters into a pre-allocated gradient structure, so that I can treat all gradients with respect to the parameters as a single vector, which is nice for sending over the network, fast reduction in parallel computation of gradients, etc.

```julia
using Flux, Zygote
using Zygote: Context, _pullback, sensitivity, Grads

function initgrad2vec(ps::Flux.Params)
    vec_ps = zeros(Float32, mapreduce(length, +, ps))
    start = 1
    gs_cache = IdDict{Any,Any}()
    for p in ps
        stop = start + length(p) - 1
        gs_cache[p] = reshape(view(vec_ps, start:stop), size(p))
        start = stop + 1
    end
    Grads(gs_cache, ps), vec_ps
end

function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
    cx = Context(gs.grads)
    for p in ps
        cx.cache[p] .= 0
    end
    y, back = _pullback(cx, f)
    back(sensitivity(y))
    y, Grads(cx.cache, ps)
end
```

When I execute this

```julia
julia> w = randn(Float32, 1, 3);

julia> x = randn(Float32, 3, 10);

julia> ps = Flux.Params([w]);

julia> f = () -> sum(w * x);

julia> gs, gs_vec = initgrad2vec(ps);

julia> gs[w]
1×3 reshape(view(::Vector{Float32}, 1:3), 1, 3) with eltype Float32:
 0.0  0.0  0.0

julia> gradient!(f, gs, ps);

julia> gs[w]
1×3 Matrix{Float32}:
 -2.92672  -0.160237  2.9666
```

I can see that when I preallocate the `Grads` structure in `gs, gs_vec = initgrad2vec(ps)`, `gs[w]` contains a view into `gs_vec`, but somehow (and I do not understand how and where) Zygote overwrites it with a new matrix during the calculation of the gradient: after I call `gradient!`, which I hoped would be in-place, `gs[w]` contains a `Matrix`, not a `view`.

Can someone tell me whether my goal is achievable, and where / how / why Zygote overwrites it?

Thanks a lot.
Tomas

I am not an expert, but I believe it is because Zygote doesn't use `.=` when updating a gradient in the cache (the way you did with `cx.cache[p] .= 0`); it does `cx.cache[p] = gradient` instead. Take it with a grain of salt, because I could be totally wrong here.

Thanks for the answer. That would explain the behavior, but then I do not know how Zygote would achieve the accumulation of the gradient. I had thought that the accumulation occurs through the `Context`'s cache.

It does accumulate gradients using `accum` (see here), but in-place mutation of the values is not required for accumulation. In your case, I think it will hit this line (called from here). As you can see, it does add the gradients, but `accum.(x, y)` will not result in the correct array type being returned.
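A quick way to see that type change in isolation (an illustration of the mechanism described above, not Zygote's actual code path):

```julia
using Zygote

v = zeros(Float32, 3)
g = reshape(view(v, 1:3), 1, 3)  # the kind of view initgrad2vec preallocates
Δ = ones(Float32, 1, 3)          # an incoming gradient from the pullback
g2 = Zygote.accum(g, Δ)          # broadcasts accum, allocating a fresh array
typeof(g2)                       # Matrix{Float32} — this new array then replaces
                                 # the view in the cache via `=`, not `.=`
```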

PS: still take all this with a grain of salt.

I see, this explains it.

I needed this kind of kick. Thanks a lot.
