Dear All,
I have a student who is implementing a few algorithms for distributed optimization (training of neural networks) using Flux / Zygote tooling. While tracking down various sources of inefficiency, I found that manipulating (adding) gradients stored in Zygote.Grads (essentially an IdDict) carries quite some overhead. So I had the idea of pre-allocating one large vector for all gradients, such that they are stored in a single contiguous array, and initializing the Grads with views into this array (see the MWE below).
using Flux, Zygote, BenchmarkTools
using Zygote: Context, _pullback, cache, sensitivity, Grads
function initgrad(ps)
    l = mapreduce(length, +, ps)
    vec_ps = zeros(Float32, l)
    start = 1
    cache = IdDict{Any,Any}()
    for p in ps
        stop = start + length(p) - 1
        # each parameter's gradient is a reshaped view into the flat vector
        cache[p] = reshape(view(vec_ps, start:stop), size(p))
        start = stop + 1
    end
    Grads(cache, ps), vec_ps
end
function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
    # reuse the pre-allocated cache as Zygote's context and zero the views
    cx = Context(gs.grads)
    for p in ps
        cache(cx)[p] .= 0
    end
    y, back = _pullback(cx, f)
    back(sensitivity(y))
    y, gs
end
function addgrad!(gs₁, gs₂)
    for p in gs₁.params
        if gs₁[p] !== nothing && gs₂[p] !== nothing
            gs₁[p] .+= gs₂[p]
        elseif gs₂[p] !== nothing
            gs₁.grads[p] = gs₂[p]
        end
    end
    gs₁
end
x = randn(Float32, 3, 10)
m = Chain(Dense(3, 4), Dense(4, 1))
m(x)
f = () -> sum(m(x));
ps = Flux.params(m)
gs = gradient(f, ps)                # reference gradients from plain Zygote
gsᵣ, gs_vec = initgrad(ps)
y = gradient!(f, gsᵣ, ps)[1]
all(gsᵣ[p] ≈ gs[p] for p in ps)     # sanity check: the view-backed grads match
Having all parameters in a single contiguous array makes reducing gradients much faster:
julia> gs₁, v₁ = initgrad(ps);
julia> gs₂, v₂ = initgrad(ps);
julia> gradient!(f, gs₁, ps);
julia> gradient!(f, gs₂, ps);
julia> @btime addgrad!(gs₁, gs₂);
2.994 μs (22 allocations: 1.34 KiB)
julia> @btime v₁ .+= v₂;
345.574 ns (2 allocations: 64 bytes)
That is, adding the flat vectors directly is almost an order of magnitude faster than walking the Grads.
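Just to make the intended use explicit: in the distributed setting each worker would own its own pre-allocated vector, and the reduction happens only over the flat vectors (reduce_grads! and worker_vecs below are illustrative names, not part of the MWE above):
function reduce_grads!(target::Vector{Float32}, worker_vecs)
    # sum the workers' flat gradient vectors into `target`
    fill!(target, 0f0)
    for v in worker_vecs
        target .+= v    # one fused broadcast per worker, no IdDict traversal
    end
    target
end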
The trouble is that this solution is not sufficiently general.
Let’s say that I define a special type of matrix which is composed of several trainable arrays. I do this frequently, for example for matrices I keep in SVD form (see https://github.com/pevnak/SumProductTransform.jl/blob/master/src/layers/svddense.jl) or for matrices handling missing values (https://github.com/CTUAvastLab/Mill.jl/blob/master/src/special_arrays/post_imputing_matrix.jl).
An MWE for such a matrix would be something like this:
struct IMatrix{T,W<:AbstractArray{T},V<:AbstractArray{T}} <: AbstractMatrix{T}
    w::W
    v::V
end

Flux.@functor IMatrix

IMatrix(r::Int, c::Int) = IMatrix(randn(Float32, r, c), randn(Float32, r))
Base.size(m::IMatrix) = size(m.w)
Base.show(io::IO, m::IMatrix) = print(io, "IMatrix $(size(m))")

import Base: *
*(a::IMatrix, b::AbstractMatrix) = a.w * b
*(a::AbstractMatrix, b::IMatrix) = a * b.w
Now, if I invoke Flux.params(model) on a model containing such a matrix, the params will include the IMatrix itself, which breaks the above scheme, since it assumes every parameter is a plain array rather than a composite type.
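For concreteness, the breakage can be provoked like this (ILayer and m2 are only illustrative names for this post, not something from my codebase):
# hypothetical wrapper layer around the IMatrix, for illustration only
struct ILayer
    A::IMatrix
end
Flux.@functor ILayer
(l::ILayer)(x) = l.A * x

m2 = Chain(ILayer(IMatrix(4, 3)), Dense(4, 1))
ps2 = Flux.params(m2)   # the IMatrix itself shows up as one "parameter"
Here initgrad(ps2) would reserve only length(p) == 12 slots for the IMatrix (its extra field v is not counted at all), and the reshaped view it stores cannot hold a gradient for the (w, v) pair, so the in-place accumulation in gradient! no longer lines up.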
Therefore my question is: does anyone know how to fix the above approach of storing grads in a single vector so that it also works with matrices implemented as composite types?
Thanks a lot for any hints.