Zygote.jl: How to get the gradient of sparse matrix

Even if you do this, I’m worried that Zygote (or Enzyme) will still try to construct a dense matrix for the tangent input to the rrule’s pullback.

For example, consider something as simple as the scalar-valued function $f(p) = x^T A(p) y$, where $A(p)$ constructs an $\ell \times m$ sparse matrix from some parameters $p \in \mathbb{R}^n$, while $x \in \mathbb{R}^\ell$ and $y \in \mathbb{R}^m$ are (dense) constant vectors. The partial derivatives are
$$\frac{\partial f}{\partial p_k} = x^T \frac{\partial A}{\partial p_k} y = \mathrm{tr}\left[yx^T \frac{\partial A}{\partial p_k}\right] = (xy^T) \cdot \frac{\partial A}{\partial p_k},$$
where $\cdot$ is the Frobenius inner product.

In a sparse situation where $\frac{\partial A}{\partial p_k}$ has only $O(1)$ nonzero entries, each $\frac{\partial f}{\partial p_k}$ can be computed in $O(1)$ operations, so the whole gradient $\nabla_p f$ can be computed in $O(n)$ operations with $O(n)$ storage (like the calculation of $f(p)$ itself) :smiley:. This could easily be implemented in an `rrule` for $f(p)$. However, if you instead define an `rrule` for $A(p)$ (or for the sparse constructor), then the input tangent vector to the $A(p)$ pullback is the rank-1 matrix $xy^T$, and if Zygote stores this as a dense matrix it will require $O(\ell m)$ storage and time :frowning_face:.
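To make the $O(n)$ approach concrete, here is a minimal sketch of such an `rrule` for $f$. The `SparseBuilder` struct is hypothetical (it just fixes a concrete form for $A(p)$), and it assumes the simplest case: each parameter $p_k$ lands at exactly one known entry $(i_k, j_k)$ of the sparse matrix, so $\partial A/\partial p_k = e_{i_k} e_{j_k}^T$ and $\partial f/\partial p_k = x_{i_k} y_{j_k}$:

```julia
using SparseArrays, LinearAlgebra, ChainRulesCore

# Hypothetical builder: A(p) is ℓ×m with sparsity pattern (I, J) and
# nonzero values equal to p (one parameter per stored entry).
struct SparseBuilder
    I::Vector{Int}
    J::Vector{Int}
    ℓ::Int
    m::Int
end

(B::SparseBuilder)(p) = sparse(B.I, B.J, p, B.ℓ, B.m)

f(B, p, x, y) = dot(x, B(p) * y)

function ChainRulesCore.rrule(::typeof(f), B::SparseBuilder, p, x, y)
    A = B(p)
    Ay = A * y
    val = dot(x, Ay)
    function f_pullback(Δ)
        # ∂f/∂pₖ = xᵢₖ yⱼₖ, since ∂A/∂pₖ = eᵢₖ eⱼₖᵀ.
        # O(n) total work and storage; the rank-1 matrix x*y' is
        # never materialized.
        p̄ = Δ .* x[B.I] .* y[B.J]
        x̄ = Δ * Ay         # ∂f/∂x = A*y
        ȳ = Δ * (A' * x)   # ∂f/∂y = Aᵀx
        return NoTangent(), NoTangent(), p̄, x̄, ȳ
    end
    return val, f_pullback
end
```

The point is that the pullback only ever touches the stored entries of the sparsity pattern, so cost scales with the number of parameters, not with $\ell m$.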

Can Zygote (or Enzyme) be easily taught to store a low-rank tangent like $xy^T$ implicitly?
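For concreteness, by "implicitly" I mean something like the following lazy rank-1 matrix type (a hypothetical sketch, not an existing Zygote/ChainRules type): it stores $x$ and $y$ in $O(\ell + m)$ memory, and a pullback for $A(p)$ could dispatch on it to read off only the $O(1)$ entries it needs:

```julia
# Hypothetical lazy rank-1 tangent: behaves like x*y' without
# materializing the dense ℓ×m matrix.
struct OuterProductTangent{T,Vx<:AbstractVector{T},Vy<:AbstractVector{T}} <: AbstractMatrix{T}
    x::Vx
    y::Vy
end

Base.size(t::OuterProductTangent) = (length(t.x), length(t.y))
Base.getindex(t::OuterProductTangent, i::Int, j::Int) = t.x[i] * t.y[j]

# A pullback for A(p) that receives this type can extract just the
# entries at the sparsity pattern (I, J), in O(n) time:
pattern_entries(t::OuterProductTangent, I, J) = t.x[I] .* t.y[J]
```

The question is whether Zygote or Enzyme can be made to propagate such a type (or something equivalent) as the tangent, rather than eagerly densifying it.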