Flux.params of a matrix implemented as a struct

I have improved my implementation to support AbstractMatrices implemented as structs, although I do not know how general the solution is (I can post it later). But I have encountered another behavior that I do not understand and cannot work out from the Zygote source.

The following snippet defines a gradient! function whose sole purpose is to accumulate the gradients of the parameters into a preallocated gradient structure, so that I can treat all gradients with respect to the parameters as a single vector. This is convenient for sending them over a network, for fast reduction when gradients are computed in parallel, etc.

using Flux, Zygote
using Zygote: Context, _pullback, cache, sensitivity, Grads

function initgrad2vec(ps::Flux.Params)
	vec_ps = zeros(Float32, mapreduce(length, +, ps))
	start = 1
	cache = IdDict{Any,Any}()
	for p in ps
		stop = start + length(p) - 1
		cache[p] = reshape(view(vec_ps, start:stop), size(p))
		start = stop + 1
	end
	Grads(cache, ps), vec_ps
end

function gradient!(f, gs::Zygote.Grads, ps::Flux.Params)
	cx = Context(gs.grads)
	for p in ps
		cx.cache[p] .= 0
	end
	y, back = _pullback(cx, f)
	back(sensitivity(y))
	y, Grads(cx.cache, ps)
end

When I execute this, I get:

julia> w = randn(Float32, 1,3);

julia> x = randn(Float32, 3, 10);

julia> ps = Flux.Params([w]);

julia> f = () -> sum(w * x);

julia> gs, gs_vec = initgrad2vec(ps);

julia> gs[w]
1×3 reshape(view(::Vector{Float32}, 1:3), 1, 3) with eltype Float32:
 0.0  0.0  0.0

julia> gradient!(f, gs, ps);

julia> gs[w]
1×3 Matrix{Float32}:
 -2.92672  -0.160237  2.9666

I can see that when I preallocate the Grads structure with gs, gs_vec = initgrad2vec(ps), gs[w] is a view into gs_vec. But somehow (I do not understand how or where) Zygote replaces it with a new matrix while computing the gradient: after I call gradient!, which I hoped would work in place, gs[w] is a Matrix, not a view.
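To make the symptom concrete, here is a minimal sketch in plain Julia (no Flux or Zygote involved) of the two ways an entry in an IdDict can be updated. The names flat, grads and Δ are only for illustration, and I am merely assuming that something like the second form happens somewhere inside Zygote; I have not confirmed that in its source.

```julia
# Plain-Julia sketch; `w`, `flat`, `grads` and `Δ` are illustrative names only.
w = zeros(Float32, 1, 3)                      # stands in for a parameter
flat = zeros(Float32, length(w))              # flat gradient storage
grads = IdDict{Any,Any}()
grads[w] = reshape(view(flat, 1:3), size(w))  # entry is a view into `flat`

Δ = ones(Float32, 1, 3)                       # stands in for an incoming gradient

# In-place accumulation: broadcasting writes through the view into `flat`.
grads[w] .+= Δ
@assert !(grads[w] isa Array)                 # still the reshaped view
@assert flat == Float32[1, 1, 1]

# Out-of-place accumulation: `+` allocates a new Matrix and the entry is rebound.
grads[w] = grads[w] + Δ
@assert grads[w] isa Matrix{Float32}          # the view is gone from the dict
@assert flat == Float32[1, 1, 1]              # `flat` did not see this update
```

After the second update the dictionary entry no longer aliases flat, which is exactly what I observe with gs[w] and gs_vec after calling gradient!.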

Can someone tell me whether my goal is achievable, and where, how, and why Zygote overwrites the view?

Thanks a lot.
Tomas
