Zygote.jl: How to get the gradient of a sparse matrix

Hi there, I am trying to differentiate through sparse matrix construction, and here is the MWE:

using LinearAlgebra
using SparseArrays
using ChainRulesCore
using Zygote

Zygote.@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    # densify the co-tangent of the sparse result to get the tangent of `arr`
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

function test1(a)
    # 2×2 sparse matrix with values a at positions given by the index vectors
    A = sparse([1, 1, 2, 2], [1, 2, 1, 2], a, 2, 2)
    return sum(A)
end

a = [1.0,2.0,3.0,4.0]
gradient(test1, a)

Then I got this error:

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#399#400")(::Nothing) at /root/.julia/packages/Zygote/CgsVi/src/lib/array.jl:58
 [3] (::Zygote.var"#2265#back#401"{Zygote.var"#399#400"})(::Nothing) at /root/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] sparse! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:862 [inlined]
 [5] (::typeof(∂(sparse!)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [6] sparse at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:703 [inlined]
 [7] (::typeof(∂(sparse)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [8] sparse at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:892 [inlined]
 [9] (::typeof(∂(sparse)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [10] test1 at /root/codes/test_zygote/test_sparse.jl:11 [inlined]
 [11] (::typeof(∂(test1)))(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#41#42"{typeof(∂(test1))})(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:41
 [13] gradient(::Function, ::Array{Float64,1}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:59
 [14] top-level scope at /root/codes/test_zygote/test_sparse.jl:16
in expression starting at /root/codes/test_zygote/test_sparse.jl:16

Then I changed sparse to the direct SparseMatrixCSC(m, n, colptr, rowval, nzval) constructor, and it works:

using LinearAlgebra
using SparseArrays
using ChainRulesCore
using Zygote

Zygote.@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    # densify the co-tangent of the sparse result to get the tangent of `arr`
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

function test2(a)
    # direct CSC construction: size (2, 2), colptr = [1, 3, 5], rowval = [1, 2, 1, 2], nzval = a
    A = SparseMatrixCSC(2, 2, [1, 3, 5], [1, 2, 1, 2], a)
    return sum(A)
end

a = [1.0,2.0,3.0,4.0]
gradient(test2, a)

Output:

([1.0, 1.0, 1.0, 1.0],)

As far as I understand, calling sparse also creates a SparseMatrixCSC, so my questions are:

  1. I know that when two indices coincide in the input index arrays, sparse adds the values of the duplicated entries (a quick check of this is below). Is this the root cause of “Mutating arrays is not supported” when calling sparse?
  2. How can I differentiate sparse directly? As far as I know, calling the SparseMatrixCSC constructor rather than the exposed API sparse is not encouraged.
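
A quick check of the duplicate-summing behavior:

julia> using SparseArrays

julia> A = sparse([1, 1], [1, 1], [1.0, 2.0], 1, 1);

julia> A[1, 1]  # the two values supplied for entry (1, 1) were summed
3.0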

Thanks for any reply.

I am facing a similar problem. Did you ever manage to solve it?

Zygote can’t differentiate through sparse-matrix constructors AFAIK. You need to write a custom rrule for that part using ChainRulesCore.jl.

(Even with AD, at some point you need to learn to take derivatives yourself, at least for some pieces of your calculation.)
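
For the constructor itself, a minimal (untested) sketch in ChainRulesCore style, mirroring the @adjoint you already wrote, might look like:

using ChainRulesCore, SparseArrays

# Sketch only: an rrule for the SparseMatrixCSC{T,N}(arr) constructor.
# The co-tangent of the sparse result is densified to give the tangent of `arr`.
function ChainRulesCore.rrule(::Type{SparseMatrixCSC{T,N}}, arr::AbstractMatrix) where {T,N}
    A = SparseMatrixCSC{T,N}(arr)
    constructor_pullback(Δ) = (NoTangent(), collect(unthunk(Δ)))
    return A, constructor_pullback
end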


While ChainRulesCore is the recommended way to do things now, custom adjoints can also be defined directly in Zygote, as @Richard-Li did.

The SparseMatrixCSC constructor is part of the API, so don't feel bad about using it 🙂 Unfortunately, the function sparse mutates a lot of arrays before calling said constructor (see the source).
So I suggest you assess whether you really need sparse or whether you can use the constructor directly. That will tell you which adjoint(s) to write. The one you made looks correct at first glance, but I didn't think about it for long.


I suggest writing an rrule for sparse and contributing it to ChainRules.jl. It’s doable and pretty straightforward.
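
Something along these lines (a rough, untested sketch for the sparse(Is, Js, V, m, n) method, treating only V as differentiable):

using ChainRulesCore, SparseArrays

function ChainRulesCore.rrule(::typeof(sparse), Is::AbstractVector{<:Integer},
                              Js::AbstractVector{<:Integer}, V::AbstractVector{<:Number},
                              m::Integer, n::Integer)
    A = sparse(Is, Js, V, m, n)
    function sparse_pullback(ΔA)
        Ā = unthunk(ΔA)
        # Each V[k] enters A linearly at position (Is[k], Js[k]); duplicates that
        # were summed in the primal simply receive the same co-tangent entry.
        V̄ = [Ā[Is[k], Js[k]] for k in eachindex(V)]
        return (NoTangent(), NoTangent(), NoTangent(), V̄, NoTangent(), NoTangent())
    end
    return A, sparse_pullback
end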


Even if you do this, won't Zygote (or Enzyme) still try to construct a dense matrix for the tangent input to the rrule's pullback?

For example, consider something as simple as the scalar-valued function $f(p) = x^T A(p) y$, where $A(p)$ constructs an $\ell \times m$ sparse matrix from some parameters $p \in \mathbb{R}^n$, while $x \in \mathbb{R}^\ell$ and $y \in \mathbb{R}^m$ are (dense) constant vectors. The partial derivatives are $\frac{\partial f}{\partial p_k} = x^T \frac{\partial A}{\partial p_k} y = \mathrm{tr}[y x^T \frac{\partial A}{\partial p_k}] = (x y^T) \cdot \frac{\partial A}{\partial p_k}$, where $\cdot$ is the Frobenius inner product. In a sparse situation where $\frac{\partial A}{\partial p_k}$ has only $O(1)$ nonzero entries, each $\frac{\partial f}{\partial p_k}$ can be computed in $O(1)$ operations, and the whole $\nabla_p f$ in $O(n)$ operations with $O(n)$ storage (like the calculation of $f(p)$ itself) 😄; this could easily be implemented in an rrule for $f(p)$. However, if you instead define an rrule for $A(p)$ (or for the sparse constructor), then the input tangent to the $A(p)$ pullback is the rank-1 matrix $x y^T$, and if Zygote stores this as a dense matrix it will require $O(\ell m)$ storage and time 🙁.
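
To make that concrete, here is a rough sketch of such an $O(n)$ rrule for $f(p)$ (real case only, and assuming the hypothetical parametrization A(p) = sparse(Is, Js, p, ℓ, m), i.e. each p[k] lands at position (Is[k], Js[k])):

using ChainRulesCore, SparseArrays, LinearAlgebra

# Hypothetical setup: f(p) = x' * A(p) * y with A(p) = sparse(Is, Js, p, ...).
f(p, Is, Js, x, y) = dot(x, sparse(Is, Js, p, length(x), length(y)), y)

function ChainRulesCore.rrule(::typeof(f), p, Is, Js, x, y)
    z = f(p, Is, Js, x, y)
    function f_pullback(Δ)
        # ∂f/∂p_k = x[Is[k]] * y[Js[k]], so the whole gradient costs O(n)
        # and never materializes the rank-1 matrix x*y'.
        p̄ = Δ .* x[Is] .* y[Js]
        return (NoTangent(), p̄, NoTangent(), NoTangent(), NoTangent(), NoTangent())
    end
    return z, f_pullback
end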

Can Zygote (or Enzyme) be easily taught to store a low-rank tangent like $x y^T$ implicitly?

What if we define an rrule for $f(p)$ that calls back into the rrule for $A(p)$? Best of both worlds?

You then lose the generic benefit of defining an rrule for the sparse constructor itself — you still need a manual rrule for any function that uses sparse matrices, and to get the full benefit of reverse-mode AD you need to manually implement the chain rule connecting the sparse-matrix constructor(s) all the way to the first low-dimensional (e.g. scalar) outputs.

That is, it’s basically the same as the current situation.

Yes. This starts to sound like the wrong level at which to solve the problem. Instead of writing special sparse rules for x' * A * y etc. to complement the sparse forward evaluation, perhaps the right level is to opt out of the existing rule for dense x' * A * y and instead differentiate the sparse forward implementation.

This is obviously what ForwardDiff does. It’s not impossible to make Zygote do this, although it’s going to involve a lot of indexing (which without thunks is expensive). It’s possible that Enzyme is already efficient at this, or could be made so.
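
(If I remember the mechanism correctly, ChainRulesCore's @opt_out is the tool for this; something like the following sketch would tell AD systems to skip the generic dense rule and differentiate the sparse implementation itself:)

using ChainRulesCore, LinearAlgebra, SparseArrays

# Sketch: opt out of the generic dot(x, A, y) rrule when A is sparse.
ChainRulesCore.@opt_out rrule(::typeof(dot), ::AbstractVector, ::SparseMatrixCSC, ::AbstractVector)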

I hope you’re not suggesting forward-mode AD here? That doesn’t scale to a large number of input parameters.

I think the answer is yes, but it will take some work to back that up with a full example. Loss of sparsity and structure in the co-tangent is a problem that hasn't received enough attention as far as I can tell, but not because it is technically impossible. Zygote passes special arrays and types as co-tangents all the time, and ChainRulesCore has a whole mechanism (ProjectTo) to project a co-tangent onto the structure of the primal input. In theory, one can define a lazy low-rank matrix like the following and then propagate it backward.

julia> using LazyArrays

julia> x = rand(1000);

julia> y = @~ x .* x';

julia> Base.summarysize(x)
8040

julia> Base.summarysize(y)
8096

Lazy arrays are not used much in ChainRules yet, but I think they should be where possible.
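
(For instance, if I read ChainRulesCore's projection code correctly, ProjectTo already keeps a sparse primal's pattern when a dense co-tangent flows back:)

julia> using ChainRulesCore, SparseArrays

julia> A = sparse([1, 2], [1, 2], [1.0, 2.0], 2, 2);

julia> dA = ProjectTo(A)(ones(2, 2));  # dense co-tangent projected onto A's pattern

julia> nnz(dA)  # the result stays sparse, with only A's two stored entries
2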


Here is an example adapted from ChainRules.

using LazyArrays, ChainRulesCore, LinearAlgebra, Zygote

mydot(x, A, y) = dot(x, A, y)

function ChainRulesCore.rrule(::typeof(mydot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
    z = dot(x, A, y)
    function dot_pullback(Ω̄)
        ΔΩ = unthunk(Ω̄)
        cΔΩ = conj(ΔΩ)
        Ay = @~ A * y            # lazy A*y
        dx = @~(cΔΩ .* Ay)       # x̄, stored lazily
        ay = adjoint(y)
        dA = @~(ΔΩ .* x .* ay)   # Ā = Ω̄ * x * y': a lazy rank-1 outer product
        aA = adjoint(A)
        dy = @~(ΔΩ .* (aA * x))  # ȳ = Ω̄ .* (A'*x)
        return (NoTangent(), dx, dA, dy)
    end
    dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent())
    return z, dot_pullback
end

julia> x = rand(200); A = rand(200, 300); y = rand(300);

julia> Zygote.pullback(mydot, x, A, y)[2](1.0)[2] |> Base.summarysize
4184

julia> Base.summarysize(x) + Base.summarysize(y)
4080
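
So the lazy co-tangent returned for the 200×300 matrix A costs about as much as storing x and y themselves, rather than the 200 × 300 × 8 = 480000 bytes a dense co-tangent matrix would need.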