Flux.params of a matrix implemented as a struct

Dear All,

I have a student who is implementing a few algorithms for distributed optimization (training of neural networks) using the Flux / Zygote tooling. We have been tracking down various sources of inefficiency, and I found that manipulating (adding) gradients stored in Zygote.Grads (essentially an IdDict) carries considerable overhead. This gave me the idea of pre-allocating one large vector for all gradients, so that they are stored in a single contiguous array, and initializing the Grads with views into that vector (see the MWE below)

using Flux, Zygote, BenchmarkTools
using Zygote: Context, _pullback, cache, sensitivity, Grads

function initgrad(ps)
	l = mapreduce(length, +, ps)
	vec_ps = zeros(Float32, l)
	start = 1
	cache = IdDict{Any,Any}()
	for p in ps
		stop = start + length(p) - 1
		cache[p] = reshape(view(vec_ps, start:stop), size(p))
		start = stop + 1
	end
	Grads(cache, ps), vec_ps
end

function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
	cx = Context(gs.grads)
	for p in ps
	  cache(cx)[p] .= 0
	end
	y, back = _pullback(cx, f)
	back(sensitivity(y))
	y, gs
end

function addgrad!(gs₁, gs₂)
  for p in gs₁.params
      if gs₁[p] !== nothing && gs₂[p] !== nothing
          gs₁[p] .+= gs₂[p]
      elseif gs₂[p] !== nothing
          gs₁.grads[p] = gs₂[p]
      end
  end
  gs₁
end

x = randn(Float32, 3, 10)
m = Chain(Dense(3,4), Dense(4,1))
m(x)
f = () -> sum(m(x));
ps = Flux.params(m)
gs = gradient(f, ps)
gsᵣ, gs_vec = initgrad(ps)
y = gradient!(f, gsᵣ, ps)[1]
all(gsᵣ[p] ≈ gs[p] for p in ps)
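The aliasing this relies on (reshaped views into a flat buffer share its storage, so in-place writes to the views land in the single vector) can be checked in isolation with plain Julia; a minimal sketch, with buf, g1, g2 as illustrative names:

```julia
# Pre-allocate one flat buffer and hand out reshaped views into it,
# mimicking what initgrad does for each parameter array.
buf = zeros(Float32, 6)
g1 = reshape(view(buf, 1:4), 2, 2)   # "gradient" of a 2×2 parameter
g2 = reshape(view(buf, 5:6), 2)      # "gradient" of a length-2 parameter

g1 .= 1f0                            # in-place writes land in buf
g2 .= 2f0
buf                                  # Float32[1, 1, 1, 1, 2, 2]

# Reducing all gradients is then a single vectorised operation:
buf2 = copy(buf)
buf .+= buf2                         # adds every gradient at once
```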

Having all gradients in a single contiguous array makes reducing them much faster:

julia> gs₁, v₁  = initgrad(ps);
julia> gs₂, v₂  = initgrad(ps);
julia> gradient!(f, gs₁, ps);
julia> gradient!(f, gs₂, ps);
julia> @btime addgrad!(gs₁, gs₂);
  2.994 μs (22 allocations: 1.34 KiB)

julia> @btime v₁ .+= v₂;
  345.574 ns (2 allocations: 64 bytes)

i.e. the flat-vector reduction is roughly an order of magnitude faster.

The trouble is that this solution is not sufficiently general.
Let’s say I define a special type of matrix composed of trainable arrays. I do this frequently, for example for matrices I keep in SVD form (see SumProductTransform.jl/svddense.jl at master · pevnak/SumProductTransform.jl · GitHub) or matrices handling missing values (see Mill.jl/post_imputing_matrix.jl at master · CTUAvastLab/Mill.jl · GitHub).

An MWE for such a matrix would look something like this:

struct IMatrix{T, W<:AbstractArray{T}, V<:AbstractArray{T}} <: AbstractMatrix{T}
	w::W
	v::V
end

Flux.@functor(IMatrix)

import Base: *
IMatrix(r::Int, c::Int) = IMatrix(randn(Float32,r,c), randn(Float32,r))

Base.size(m::IMatrix) = size(m.w)
Base.show(io::IO, m::IMatrix) = print(io, "IMatrix $(size(m))")
*(a::IMatrix, b::AbstractMatrix) = a.w * b
*(a::AbstractMatrix, b::IMatrix) = a * b.w

Now, if I invoke Flux.params(model), it will contain the IMatrix itself, which breaks the above scheme, as the scheme assumes every parameter is a plain array rather than a composite type.

My question therefore is: does anyone know how to fix the above approach of storing grads in a vector so that it also works with matrices implemented as composite types?

Thanks a lot for any hints.

Need IMatrix be a subtype of AbstractMatrix? If not, the functorization should work properly. If so, you could overload params!.

Also RE manipulating grads, have you looked into using Flux.destructure to get a parameter vector?

Thanks for the answer.

I know if I do not subtype it as an AbstractMatrix, the problem is solved, but I would actually like it to be a subtype of AbstractMatrix and that is the problem.

I had not thought about using destructure / restructure for this, which is neat, but I guess my solution has lower overhead, as I do not need to copy the gradients (on the CPU, of course). Wouldn’t destructure and restructure suffer from the same problem?
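For context, the destructure idea can be sketched without Flux at all; this is a minimal, self-contained illustration for plain arrays (flatten_params and rebuild are hypothetical names, not Flux’s actual implementation, which handles nested models):

```julia
# Sketch of the destructure / restructure pattern: flatten a list of
# arrays into one vector, plus a closure that rebuilds the shapes.
function flatten_params(arrays)
    flat = reduce(vcat, vec.(arrays))      # column-major flattening
    shapes = size.(arrays)
    function rebuild(v)
        out, start = [], 1
        for s in shapes
            stop = start + prod(s) - 1
            push!(out, reshape(v[start:stop], s))
            start = stop + 1
        end
        out
    end
    flat, rebuild
end

w = Float32[1 2; 3 4]; b = Float32[5, 6]
flat, rebuild = flatten_params([w, b])
w2, b2 = rebuild(2 .* flat)                # w2 == 2w, b2 == 2b
```

Note that rebuild here copies (v[start:stop] allocates), which is exactly the overhead mentioned above.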

Check out ComponentArrays

I did,

can you please be more specific and point me to how it solves my problem?

I know if I do not subtype it as an AbstractMatrix, the problem is solved, but I would actually like it to be a subtype of AbstractMatrix and that is the problem.

As @ToucheSir suggested, you can overload params!. E.g.

function Flux.params!(p::Params, x::IMatrix, seen = IdSet())
  x in seen && return
  push!(seen, x)
  for child in trainable(x)
    params!(p, child, seen)
  end
end

So this partially works. I have implemented this idea, but the gradients of the parameters are not stored in gs properly. See:

mi = Chain(Dense(IMatrix(4,3), randn(Float32,4), identity), 
		Dense(IMatrix(1,4), randn(Float32,1), identity))
ps = Flux.params(mi)
mi(x)
gs = gradient(() -> sum(mi(x)), ps)
julia> [gs[p] for p in ps]
4-element Vector{Array{Float32, N} where N}:
 [-0.052534547 0.31615198 0.34340757; 0.020920932 -0.1259018 -0.13675584; 1.0563443 -6.3570614 -6.9051056; -1.6292106 9.804562 10.649816]
 [0.5259397, -0.20944595, -10.575391, 16.310537]
 [3.0866313 5.0587535 24.052073 -4.2675037]
 [10.0]

but when I add

Flux.params!(p::Params, x::IMatrix, seen = IdSet()) = (push!(p, x.w); push!(p, x.v))
ps = Flux.params(mi)
mi(x)
gs = gradient(() -> sum(mi(x)), ps)
julia> [gs[p] for p in ps]
6-element Vector{Union{Nothing, Vector{Float32}}}:
 nothing
 nothing
 Float32[0.5259397, -0.20944595, -10.575391, 16.310537]
 nothing
 nothing
 Float32[10.0]

i.e. the gradients of the parameters of the IMatrices are not written to the Zygote.Grads structure correctly. I guess it might have something to do with defining a custom gradient,
which should be a breeze, but I do not know how to do it so that they are assigned correctly. I have tried something like

Zygote.@adjoint function *(a::IMatrix, b::AbstractMatrix)
	a * b, Δ -> begin 
		((w = Δ * b', v = nothing), a' * Δ)
	end
end

but it does not work.

Any help is appreciated.

I have improved my implementation to support AbstractMatrices implemented as structs, even though I do not know how general the solution is (I can post it later). But I have encountered another behavior that I do not understand and cannot decipher from the Zygote source.

The following snippet defines a gradient! whose sole purpose is to accumulate the gradients of the parameters into a pre-allocated gradient structure, so that I can treat all gradients with respect to the parameters as a single vector, which is convenient for sending over the network, fast reduction in parallel computation of gradients, etc.

using Flux, Zygote
using Zygote: Context, _pullback, cache, sensitivity, Grads

function initgrad2vec(ps::Flux.Params)
	vec_ps = zeros(Float32, mapreduce(length, +, ps))
	start = 1
	cache = IdDict{Any,Any}()
	for p in ps
		stop = start + length(p) - 1
		cache[p] = reshape(view(vec_ps, start:stop), size(p))
		start = stop + 1
	end
	Grads(cache, ps), vec_ps
end

function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
	cx = Context(gs.grads)
	for p in ps
	  cx.cache[p] .= 0
	end
	y, back = _pullback(cx, f)
	back(sensitivity(y))
	y, Grads(cx.cache, ps)
end

When I execute this

julia> w = randn(Float32, 1,3);

julia> x = randn(Float32, 3, 10);

julia> ps = Flux.Params([w]);

julia> f = () -> sum(w * x);

julia> gs, gs_vec = initgrad2vec(ps);

julia> gs[w]
1×3 reshape(view(::Vector{Float32}, 1:3), 1, 3) with eltype Float32:
 0.0  0.0  0.0

julia> gradient!(f, gs, ps);

julia> gs[w]
1×3 Matrix{Float32}:
 -2.92672  -0.160237  2.9666

I can see that when I preallocate the Grads structure in gs, gs_vec = initgrad2vec(ps), gs[w] contains a view into gs_vec, but somehow (and I do not understand how or where) Zygote overwrites it with a new matrix during the gradient computation: after I call gradient!, which I hoped would be in-place, gs[w] contains a Matrix, not a view.

Can someone tell me whether my goal is achievable, and where / how / why Zygote overwrites it?

Thanks a lot.
Tomas


I am not an expert, but I believe it is because Zygote doesn’t use .= when updating a gradient in the cache (as you did with cx.cache[p] .= 0); it does cx.cache[p] = gradient instead. Take it with a grain of salt, as I could be totally wrong here.

Thanks for the answer. That would explain the behavior, but then I do not know how Zygote achieves accumulation of the gradient. I thought the accumulation occurred through the cache.

It does accumulate gradients using accum (see here), but in-place mutation of the values is not required for accumulation. In your case, I think it will hit this line (called from here). As you can see, it does add the gradients, but accum.(x, y) will not preserve the original array type.
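If it helps, the rebinding can be reproduced without Zygote at all; a minimal sketch (accum here is a simplified stand-in for the real Zygote.accum):

```julia
# Sketch of why the view disappears: accum returns a NEW array (x .+ y),
# and the cache slot is rebound to it rather than mutated in place.
accum(x, y) = x .+ y                  # simplified stand-in for Zygote.accum

buf = zeros(Float32, 3)
cache = IdDict{Any,Any}()
p = Float32[1, 2, 3]                  # a "parameter"
cache[p] = view(buf, 1:3)             # pre-allocated view, as in initgrad2vec

grad = Float32[10, 20, 30]            # a freshly computed gradient
cache[p] = accum(cache[p], grad)      # rebinds the slot to a plain Array

cache[p] isa SubArray                 # false: the view is gone
buf                                   # still all zeros: never written to
```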

PS: still take all this with a grain of salt :grinning:


I see, this explains it.

I needed this kind of kick. Thanks a lot.
