I have improved my implementation to support `AbstractMatrices` implemented as a `struct`, even though I do not know how general the solution is (I can post it later). But I have encountered another behavior that I do not understand and cannot work out from the Zygote source.

The following snippet defines a `gradient!` function whose sole purpose is to accumulate the gradients of the parameters into a preallocated gradient structure, so that I can treat all gradients with respect to the parameters as a single vector. That is convenient for sending them over a network, for fast reduction in parallel gradient computation, etc.
using Flux, Zygote
using Zygote: Context, _pullback, cache, sensitivity, Grads

function initgrad2vec(ps::Flux.Params)
    vec_ps = zeros(Float32, mapreduce(length, +, ps))
    start = 1
    cache = IdDict{Any,Any}()
    for p in ps
        stop = start + length(p) - 1
        cache[p] = reshape(view(vec_ps, start:stop), size(p))
        start = stop + 1
    end
    Grads(cache, ps), vec_ps
end

function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
    cx = Context(gs.grads)
    for p in ps
        cx.cache[p] .= 0
    end
    y, back = _pullback(cx, f)
    back(sensitivity(y))
    y, Grads(cx.cache, ps)
end
When I execute this, I get:
julia> w = randn(Float32, 1,3);
julia> x = randn(Float32, 3, 10);
julia> ps = Flux.Params([w]);
julia> f = () -> sum(w * x);
julia> gs, gs_vec = initgrad2vec(ps);
julia> gs[w]
1×3 reshape(view(::Vector{Float32}, 1:3), 1, 3) with eltype Float32:
0.0 0.0 0.0
julia> gradient!(f, gs, ps);
julia> gs[w]
1×3 Matrix{Float32}:
-2.92672 -0.160237 2.9666
I can see that when I preallocate the `Grads` structure in `gs, gs_vec = initgrad2vec(ps)`, `gs[w]` contains a view into `gs_vec`, but somehow (and I do not understand how or where) Zygote replaces it with a new matrix during the gradient computation: after I call `gradient!`, which I hoped would be in-place, `gs[w]` contains a `Matrix`, not a view.
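To make the distinction concrete, here is a minimal sketch in plain Julia (no Zygote involved) of the two behaviors I am contrasting: broadcasting into an existing dictionary entry versus assigning a new array to it. The names `buf`, `key` and `d` are made up for this illustration, and I am only assuming that something like the second case happens inside Zygote; I have not confirmed it.

```julia
# Hypothetical names (buf, key, d), chosen only for this illustration.
buf = zeros(Float32, 3)                      # flat storage, plays the role of gs_vec
key = randn(Float32, 1, 3)                   # stands in for the parameter w
d = IdDict{Any,Any}()
d[key] = reshape(view(buf, 1:3), 1, 3)       # the entry aliases buf

# In-place update: broadcasting writes through the view into buf.
d[key] .+= Float32[1 2 3]
inplace_buf = copy(buf)                      # snapshot of buf after the in-place update

# Rebinding: the right-hand side allocates a new array, and the
# dictionary entry is replaced by it. buf is not touched.
d[key] = d[key] .+ Float32[1 2 3]
rebound_is_matrix = d[key] isa Matrix{Float32}
buf_unchanged = buf == inplace_buf
```

After the first update the entry is still a view and `buf` holds the new values. After the second, the entry is a plain `Matrix` and `buf` no longer follows it, which is the same symptom I see with `gs[w]` and `gs_vec`.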
Can someone tell me whether my goal is possible, and where, how, and why Zygote overwrites the view?
Thanks a lot.
Tomas