How to efficiently differentiate backslash operator for sparse matrix?

The problem that you are left with is that Zygote doesn’t know how to pull back through A = sparse(i,j,v), and the solution is to provide this pullback. I believe the pullback for SparseMatrixCSC() is easier than that for sparse(), so what I did is to replace

A = sparse([1, 1, 2, 2], [1, 2, 1, 2], a)

with

A = SparseMatrixCSC(2,2,[1, 3, 5], [1, 2, 1, 2],a)

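Note that CSC storage is column-major, so this constructor call yields `[a[1] a[3]; a[2] a[4]]`, i.e. the transpose of the `sparse` call above; the exact equivalent would be `sparse([1, 2, 1, 2], [1, 1, 2, 2], a)`. With that replacement in place, define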

# Reverse rule for building a `SparseMatrixCSC` from its raw CSC fields.
# Only the stored values `Av` get a cotangent; the dimensions `m`, `n` and the
# index vectors `pp` (column pointers) and `ii` (row indices) are marked
# `DoesNotExist()`.
function ChainRulesCore.rrule(
    ::typeof(SparseArrays.SparseMatrixCSC),
    m::Integer, n::Integer,
    pp::Vector, ii::Vector, Av::Vector
)
    primal = SparseMatrixCSC(m, n, pp, ii, Av)

    function SparseMatrixCSC_pullback(dA)
        # Gather the cotangent at every stored position of the primal:
        # slot `k` of column `col` lives at row `ii[k]`.
        grad_vals = Vector{eltype(dA)}(undef, length(Av))
        for col in 1:n
            for k in pp[col]:(pp[col+1] - 1)
                grad_vals[k] = dA[ii[k], col]
            end
        end
        no_grad = DoesNotExist()
        return (NO_FIELDS, no_grad, no_grad, no_grad, no_grad, grad_vals)
    end

    return primal, SparseMatrixCSC_pullback
end

This should work, but it isn’t quite as efficient as it could be. The problem is that the pullback for mybackslash computes a dense dA even though the pullback for SparseMatrixCSC() only looks at a subset of the entries of dA, so to do better we should only compute those entries of dA which are actually needed. This can be done as follows.

# Reverse rule for `mybackslash(A, b)` with a sparse `A`.
#
# Returns the solution `c = A\b` and a pullback mapping the cotangent `dc` to
# `(NO_FIELDS, dA, db)`, where `db = A'\dc` and `dA` is a thunk producing a
# sparse matrix with the same stored pattern as `A`. Only the entries of
# `-db * c'` at stored positions of `A` are computed, so no dense matrix is
# ever formed.
function ChainRulesCore.rrule(::typeof(mybackslash), A::SparseMatrixCSC, b::AbstractVector)
    c = A\b
    function mybackslash_pullback(dc)
        db = A'\collect(dc)
        # ^ When called through Zygote, `dc` is a `FillArrays.Fill`, and
        # `A'\dc` throws an error for such `dc`. The `collect()` here is a
        # workaround for this bug.

        # The thunk body must not assign to `dA`: doing so would rebind the
        # enclosing local from inside the closure, which boxes the variable
        # and overwrites the thunk binding once the thunk is forced.
        dA = @thunk begin
            m,n,pp,ii = A.m,A.n,A.colptr,A.rowval
            # Element type of the products actually stored below.
            dAv = Vector{typeof(zero(eltype(db)) * zero(eltype(c)))}(undef, length(A.nzval))
            for j = 1:n, p = pp[j]:pp[j+1]-1
                # Entry (ii[p], j) of `-db * c'`; `conj` matches the adjoint
                # used by the generic rule and is a no-op for real `c`.
                dAv[p] = -db[ii[p]] * conj(c[j])
            end
            # Built from `A`'s own `colptr`/`rowval`, so the result has
            # exactly `A`'s stored pattern.
            SparseMatrixCSC(m,n,pp,ii,dAv)
        end
        return (NO_FIELDS, dA, db)
    end
    return c, mybackslash_pullback
end

Finally, you can optimise the SparseMatrixCSC() pullback a little bit by picking the entries in a more efficient way if dA is already a sparse matrix. I omit the details here and instead simply copy-paste the complete code.

using ChainRulesCore
using LinearAlgebra
using SparseArrays

# Thin wrapper around `\`: returns `A \ B` unchanged. Having a dedicated
# function name gives custom reverse rules something to dispatch on.
function mybackslash(A, B)
    return A \ B
end

# For testing
# Generic reverse rule for `mybackslash(A, B)`; the test suite uses it as the
# dense reference against which the sparse rule is compared.
function ChainRulesCore.rrule(::typeof(mybackslash), A, B)
    solution = A \ B
    function mybackslash_pullback(dC)
        # Cotangent of the right-hand side: solve the adjoint system.
        rhs_bar = A' \ dC
        # Cotangent of `A`, deferred so it is only computed when requested.
        mat_bar = @thunk(-rhs_bar * solution')
        return (NO_FIELDS, mat_bar, rhs_bar)
    end
    return solution, mybackslash_pullback
end

# Reverse rule for `mybackslash(A, b)` with a sparse `A`.
#
# Returns the solution `c = A\b` and a pullback mapping the cotangent `dc` to
# `(NO_FIELDS, dA, db)`, where `db = A'\dc` and `dA` is a thunk producing a
# sparse matrix with the same stored pattern as `A`. Only the entries of
# `-db * c'` at stored positions of `A` are computed, so no dense matrix is
# ever formed.
function ChainRulesCore.rrule(::typeof(mybackslash), A::SparseMatrixCSC, b::AbstractVector)
    c = A\b
    function mybackslash_pullback(dc)
        db = A'\collect(dc)
        # ^ When called through Zygote, `dc` is a `FillArrays.Fill`, and
        # `A'\dc` throws an error for such `dc`. The `collect()` here is a
        # workaround for this bug.

        # The thunk body must not assign to `dA`: doing so would rebind the
        # enclosing local from inside the closure, which boxes the variable
        # and overwrites the thunk binding once the thunk is forced.
        dA = @thunk begin
            m,n,pp,ii = A.m,A.n,A.colptr,A.rowval
            # Element type of the products actually stored below.
            dAv = Vector{typeof(zero(eltype(db)) * zero(eltype(c)))}(undef, length(A.nzval))
            for j = 1:n, p = pp[j]:pp[j+1]-1
                # Entry (ii[p], j) of `-db * c'`; `conj` matches the adjoint
                # used by the generic rule and is a no-op for real `c`.
                dAv[p] = -db[ii[p]] * conj(c[j])
            end
            # Built from `A`'s own `colptr`/`rowval`, so the result has
            # exactly `A`'s stored pattern.
            SparseMatrixCSC(m,n,pp,ii,dAv)
        end
        return (NO_FIELDS, dA, db)
    end
    return c, mybackslash_pullback
end

# Reverse rule for building a `SparseMatrixCSC` from its raw CSC fields.
#
# Only the stored values `Av` get a cotangent; the dimensions `m`, `n` and the
# index vectors `pp` (column pointers) and `ii` (row indices) are marked
# `DoesNotExist()`. The pullback accepts any `AbstractMatrix` cotangent and
# has a fast path for a sparse cotangent whose stored pattern equals that of
# the primal.
function ChainRulesCore.rrule(
    ::typeof(SparseArrays.SparseMatrixCSC),
    m::Integer, n::Integer,
    pp::Vector, ii::Vector, Av::Vector
)
    A = SparseMatrixCSC(m,n,pp,ii,Av)

    # Generic extraction: read the cotangent at every stored position of `A`.
    function gather_stored(dA)
        dAv = Vector{eltype(dA)}(undef, length(Av))
        for j = 1:n, p = pp[j]:pp[j+1]-1
            dAv[p] = dA[ii[p],j]
        end
        return dAv
    end

    function SparseMatrixCSC_pullback(dA::AbstractMatrix)
        return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), gather_stored(dA))
    end
    function SparseMatrixCSC_pullback(dA::SparseMatrixCSC)
        # Explicit check instead of `@assert`, which may be compiled out.
        size(dA) == size(A) || throw(DimensionMismatch(
            "cotangent has size $(size(dA)), expected $(size(A))"
        ))
        # Identical stored pattern: the value vector of `dA` is already the
        # answer. Otherwise fall back to indexing, which reads zero wherever
        # `dA` has no stored entry.
        same_pattern = dA.colptr == A.colptr && dA.rowval == A.rowval
        dAv = same_pattern ? dA.nzval : gather_stored(dA)
        return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
    end

    return A, SparseMatrixCSC_pullback
end


using Test
using Zygote

# Run the test suite for `mybackslash` and the reverse rules it relies on:
#   1. the `SparseMatrixCSC` constructor rule, with dense and sparse cotangents;
#   2. the sparse `mybackslash` rule, compared against the generic rule
#      evaluated on the equivalent dense matrix;
#   3. end-to-end differentiation through Zygote, checked against the
#      analytic gradient.
function test()
    @testset "SparseMatrixCSC rrule" begin
        m,n = 3,2
        pp = [1,3,4]
        ii = [1,3,2]
        Av = collect(11.0:13.0)
        dAv = collect(21.0:23.0)

        # The same cotangent, once in sparse and once in dense form.
        dA_sparse = SparseMatrixCSC(m,n,pp,ii,dAv)
        dA_dense = [
            dAv[1]   0
              0    dAv[3]
            dAv[2]   0
        ]

        A,pb = rrule(SparseMatrixCSC, m,n,pp,ii,Av)
        @test A == SparseMatrixCSC(m,n,pp,ii,Av)
        @test pb(dA_dense ) == (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
        @test pb(dA_sparse) == (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
    end

    @testset "mybackslash" begin
        m,n = 3,3
        pp = [1,3,4,6]
        ii = [1,3,2,1,2]
        Av = collect(11.0:15.0)
        b = collect(21.0:23.0)
        dc = collect(31.0:33.0)

        As = SparseMatrixCSC(m,n,pp,ii,Av)
        Ad = Matrix(As)

        cd,pbd = rrule(mybackslash, Ad, b)
        _,dAd,dbd = unthunk.(pbd(dc))
        # The sparse rule only produces entries where `As` is nonzero, so
        # zero out the remaining entries of the dense result before comparing.
        dAd[As .== 0] .= 0
        cs,pbs = rrule(mybackslash, As, b)
        _,dAs,dbs = unthunk.(pbs(dc))

        @test cd ≈ cs
        @test dAd ≈ dAs
        @test dbd ≈ dbs
    end

    @testset "Zygote integration" begin
        function myfun(a)
            # CSC storage is column-major: A == [a[1] a[3]; a[2] a[4]]
            A = SparseMatrixCSC(2,2,[1, 3, 5], [1, 2, 1, 2],a)
            b = [1.0;1.0]
            x = mybackslash(A,b)
            return sum(x)
        end

        p = [6.0,1.0,1.0,3.0]
        g = gradient(myfun, p)
        @show g

        # Analytic reference: for f(A) = sum(A\b) the gradient with respect to
        # A is -(A'\1) * (A\b)', and `vec` lists its entries in the same
        # column-major order in which `p` fills the matrix.
        Ad = [p[1] p[3]; p[2] p[4]]
        b = [1.0, 1.0]
        dAd = -(Ad' \ ones(2)) * (Ad \ b)'
        @test g[1] ≈ vec(dAd)
    end
end
4 Likes