How to efficiently differentiate backslash operator for sparse matrix?

My problem is to calculate the gradient dL(u)/dp, where L() is a loss function that takes a vector u as input and returns a scalar. The vector p holds the parameters that define a large sparse matrix K, such that K(p)u = F, where F is a constant vector. As you may know, Ku = F is the system we need to solve in structural finite element analysis.

Since L(u) is a scalar and p is a long vector, I would use ReverseDiff to calculate dL(u)/dp. However, the current methods in the existing packages (forward or reverse mode) do not support sparse matrices.

It works if K is a full matrix like this:

using ReverseDiff: gradient, jacobian

function myfun(p)
    K = [p[1] p[2]; p[3] p[4]]
    u = K\[1.0;1.0]
end
@show jacobian(myfun, [2.0, 3.0, 3.0,4.0])

But the code below does NOT work.
BTW, I can make it run by replacing \ with cg from IterativeSolvers, but the gradient calculated that way is not right. I won’t go into details here, but let me know if you know why.

using ReverseDiff: gradient, jacobian
using SparseArrays

function myfun(p)
    K = sparse([1, 1, 2, 2], [1, 2, 1, 2], p)  # index arguments must be vectors, not 1×4 matrices
    u = K\[1.0;1.0]
end
@show jacobian(myfun, [2.0, 3.0, 3.0,4.0])

I learnt from somewhere that JuAFEM is a differentiable finite element code, but I haven’t used it yet. Is JuAFEM capable of differentiating through Ku = F? I guess the code has to do this calculation at some point to solve for u. If the answer is yes, can anyone tell me how it is done? Does it define a custom adjoint in the code? Thanks.
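
For reference, the gradient asked for here has a standard closed form via the adjoint method. The following is a sketch of that textbook derivation, assuming K(p) is invertible and differentiable in p and that F does not depend on p. The adjoint vector λ is notation introduced here; it does not appear in the original post.

```latex
% Differentiate the constraint K(p) u = F with respect to p_i (F is constant):
\frac{\partial K}{\partial p_i}\, u + K \frac{\partial u}{\partial p_i} = 0
\quad\Longrightarrow\quad
\frac{\partial u}{\partial p_i} = -K^{-1} \frac{\partial K}{\partial p_i}\, u

% Chain rule for the scalar loss, with the adjoint vector \lambda:
\frac{dL}{dp_i}
  = \left(\frac{\partial L}{\partial u}\right)^{\!\top} \frac{\partial u}{\partial p_i}
  = -\lambda^{\top} \frac{\partial K}{\partial p_i}\, u,
\qquad \text{where } K^{\top} \lambda = \frac{\partial L}{\partial u}
```

Under these assumptions, one additional linear solve with the transposed matrix yields the whole gradient, however long p is.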

I think that adding the relevant missing methods may be the easiest solution. If you don’t get any specific replies here, I would suggest asking in an issue at

It’s already in Zygote:

https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L404-L442

Hi Chris,
Maybe I was not clear in the title (just edited): what I want is the derivative for a sparse matrix. I don’t think the code you mentioned can be applied to the backslash of a sparse matrix. Below are examples for the cases I tested. Please run them if you have time, and let me know if you find I am wrong or spot something new. Thank you for the help.

Some description about this code:

  1. First create a 2*2 symmetric positive definite matrix, A.
  2. Solve Ax = u using cg to check correctness.
  3. Calculate the gradient dx/dAij using finite differences.
  4. Calculate the gradient g1 = dx/dAij using backslash when A is dense. Works.
  5. Calculate the gradient g2 = dx/dAij using cg when A is sparse. Works only when A is symmetric with a constant diagonal, i.e. aii = constant.
  6. Calculate the gradient g3 = dx/dAij using backslash when A is sparse. Throws an error: no method matching lu!(::SparseMatrixCSC…)

using Zygote
using SparseArrays
using IterativeSolvers
using Zygote: forwarddiff

# 2*2 matrix index for sparse matrix, A
i = [1, 1, 2, 2];
j = [1, 2, 1, 2];
a11 = 6.0; a12 = 1.0; a21 = 1.0; a22 = 3.0;
A = sparse(i,j,[a11, a12, a21, a22])

# right hand side of Ax = u
u = [1.0;1.0]

# check cg correctness to solve Ax = u
x = cg(A,u)
res = A*x-u
println("cg method residual ",res)

# gradient calculation using finite difference
eps = 0.00001;
A = [a11 a12; a21 a22];
Ad = [a11+eps a12; a21 a22];
e1 = (sum(Ad\u) - sum(A\u))./eps
Ad = [a11 a12+eps; a21 a22];
e2 = (sum(Ad\u) - sum(A\u))./eps
Ad = [a11 a12; a21+eps a22];
e3 = (sum(Ad\u) - sum(A\u))./eps
Ad = [a11 a12; a21 a22+eps];
e4 = (sum(Ad\u) - sum(A\u))./eps
println("gradient by finite difference  ", (e1, e2, e3, e4))

# gradient calculation when A is full/dense, works
g1 = gradient(a11, a12, a21, a22) do a, b, c, d
  forwarddiff([a,b,c,d]) do (a,b,c,d)
    A = [a b;c d]
    x = A\u
    sum(x)
  end
end
println("backslash of full matrix  ", g1)


# gradient calculation when A is sparse
# works using cg only when A is symmetric and constant along the diagonal, i.e. a11 = a22.
g2 = gradient(a11, a12, a21, a22) do a, b, c, d
  forwarddiff([a,b,c,d]) do (a,b,c,d)
    A = sparse(i,j,[a, b, c, d])
    x = cg(A,u)
    sum(x)
  end
end
println("cg of sparse   ", g2)

# gradient calculation when A is sparse
# not working using backslash
g3 = gradient(a11, a12, a21, a22) do a, b, c, d
  forwarddiff([a,b,c,d]) do (a,b,c,d)
    A = sparse(i,j,[a, b, c, d])
    x = A\u 
    sum(x)
  end
end
println("backslash of sparse   ", g3)

What’s the recommended way to get a Jacobian out of Zygote without using forward_jacobian (which would go through the forward mode, and I suspect this would not work)?

Hi Tamas,
I really want to write my own method, although I am not sure I can manage it. I followed the Zygote documentation on custom adjoints (Custom Adjoints · Zygote). I will take a look at ChainRules first and see how far I can get.

Use it to define the rule for ReverseDiff. It should be a copy-paste with @grad.

Forward mode would work, and you may want to use SparseDiffTools.jl (JuliaDiff/SparseDiffTools.jl on GitHub) here. But if you really need reverse mode, then this would do it:

using Zygote
function jacobian(f,x)
    y,back  = Zygote.pullback(f,x)
    k  = length(y)
    n  = length(x)
    J  = Matrix{eltype(y)}(undef,k,n)
    # The seed lives in the output space, so it must have the shape of y, not x
    e_i = zero(y)
    for i = 1:k
        e_i[i] = oneunit(eltype(y))
        J[i,:] = back(e_i)[1]   # row i of the Jacobian
        e_i[i] = zero(eltype(y))
    end
    J
end

hessian(f, x) = jacobian(x -> Zygote.gradient(f, x)[1], x)

f(x) = [x[1],x[2]^2 + x[1]]
x = [1.0,2.0]
jacobian(f,x)

using ForwardDiff
ForwardDiff.jacobian(f,x)

Do you mean to define a custom rule for ReverseDiff? Is the functionality of @grad similar to @adjoint?

it’s almost exactly the same

Could you provide a simple example of how to use @grad? I was trying to do this, but it is not working.

using ReverseDiff: gradient,  @grad
minus(a,b) = a - b
@grad function minus(a, b)
  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
@show gradient(minus, 1,1)

interesting, that looks correct.

Oh, the values might have to be arrays?

Hi Chris,
Here is what I did: I defined a function called mybackslash and its backward rule. It still only works for a dense matrix. When the matrix A in myfun is sparse, it gives me an error.

using ReverseDiff: gradient, jacobian, @grad
using LinearAlgebra
using SparseArrays

function mybackslash(A,B)
    return A \ B
end

@grad function mybackslash(A,B)
    Y = A \ B
    return Y, function(Ȳ)
        B̄ = A' \ Ȳ
        return (-B̄ * Y', B̄)
    end
end

function myfun(a)
    #A = [a[1] a[2]; a[3] a[4]]
    A = sparse([1, 1, 2, 2], [1, 2, 1, 2],a)
    b = [1.0;1.0]
    
    x = mybackslash(A,b)
    return sum(x)
end

p = [6.0,1.0,1.0,3.0]
g = gradient(myfun, p)

Oh I guess ReverseDiff’s type support is getting in the way. Is there a reason you can’t Zygote this?

For Zygote, I basically replaced @grad with @adjoint. But it still throws an error that says something like “no method matching lu!(::SparseMatrixCSC…)”.

What’s calling the lu!? Just make it an allocating one.

Sorry to keep asking you. Actually, when I run the following code using Zygote, the error is “Need an adjoint for constructor SparseMatrixCSC{Float64,Int64}. Gradient is of type Array{Float64,2}”.

using Zygote
using Zygote: @adjoint
using LinearAlgebra
using SparseArrays

function mybackslash(A::SparseMatrixCSC,B)
    return A \ B
end

@adjoint function mybackslash(A,B)
  Y = A \ B
  return Y, function(Ȳ)
    B̄ = A' \ Ȳ
    return (-B̄ * Y', B̄)
  end
end

function myfun(a)
    #A = [a[1] a[2]; a[3] a[4]]
    A = sparse([1, 1, 2, 2], [1, 2, 1, 2],a)
    b = [1.0;1.0]
    
    x = mybackslash(A,b)
    return sum(x)
end

p = [6.0,1.0,1.0,3.0]
g = gradient(myfun, p)

The problem that you are left with is that Zygote doesn’t know how to pull back through A = sparse(i,j,v), and the solution is to provide this pullback. I believe the pullback for SparseMatrixCSC() is easier than that for sparse(), so what I did is to replace

A = sparse([1, 1, 2, 2], [1, 2, 1, 2], a)

with

A = SparseMatrixCSC(2,2,[1, 3, 5], [1, 2, 1, 2],a)

and define

function ChainRulesCore.rrule(
    ::typeof(SparseArrays.SparseMatrixCSC),
    m::Integer, n::Integer,
    pp::Vector, ii::Vector, Av::Vector
)
    A = SparseMatrixCSC(m,n,pp,ii,Av)

    function SparseMatrixCSC_pullback(dA)
        # Pick out the entries in `dA` corresponding to nonzeros in `A`
        dAv = Vector{eltype(dA)}(undef, length(Av))
        for j = 1:n, p = pp[j]:pp[j+1]-1
            dAv[p] = dA[ii[p],j]
        end
        return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
    end

    return A, SparseMatrixCSC_pullback
end
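
To see what the loop in this pullback picks out, it may help to look at the CSC layout directly. The sketch below uses only the standard library; the nonsymmetric values in `a` and the dense `dA` are made up for illustration. One thing it shows is that CSC stores values column by column, so `SparseMatrixCSC(2,2,[1, 3, 5],[1, 2, 1, 2],a)` matches `sparse([1, 2, 1, 2], [1, 1, 2, 2], a)` rather than `sparse([1, 1, 2, 2], [1, 2, 1, 2], a)`. The two only coincide for a symmetric matrix such as the p = [6.0,1.0,1.0,3.0] used in this thread.

```julia
using SparseArrays

# Deliberately nonsymmetric values (a[2] != a[3]) so the layouts can be told apart
a = [6.0, 2.0, 1.0, 3.0]

A_coo = sparse([1, 1, 2, 2], [1, 2, 1, 2], a)              # entries listed row by row
A_csc = SparseMatrixCSC(2, 2, [1, 3, 5], [1, 2, 1, 2], a)  # entries stored column by column

# In CSC storage a[2] lands at (2,1), not (1,2)
@assert Matrix(A_coo) == [6.0 2.0; 1.0 3.0]
@assert Matrix(A_csc) == [6.0 1.0; 2.0 3.0]
@assert A_csc == sparse([1, 2, 1, 2], [1, 1, 2, 2], a)

# The same gather loop as in the pullback: for each stored entry p in column j,
# read the matching entry of a (here dense, made-up) cotangent dA
dA = [10.0 30.0; 20.0 40.0]
pp, ii = A_csc.colptr, A_csc.rowval
dAv = similar(a)
for j in 1:size(A_csc, 2), p in pp[j]:pp[j+1]-1
    dAv[p] = dA[ii[p], j]
end
@assert dAv == [10.0, 20.0, 30.0, 40.0]
```

If the parameter ordering of the original `sparse` call matters for a nonsymmetric K, the gradient entries would need to be permuted accordingly.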

This should work, but it isn’t quite as efficient as it could be. The problem is that the pullback for mybackslash computes a dense dA even though the pullback for SparseMatrixCSC() only looks at a subset of the entries of dA, so to do better we should only compute those entries of dA which are actually needed. This can be done as follows.

function ChainRulesCore.rrule(::typeof(mybackslash), A::SparseMatrixCSC, b::AbstractVector)
    c = A\b
    function mybackslash_pullback(dc)
        db = A'\collect(dc)
        # ^ When called through Zygote, `dc` is a `FillArrays.Fill`, and
        # `A'\dc` throws an error for such `dc`. The `collect()` here is a
        # workaround for this bug.

        dA = @thunk begin
            m,n,pp,ii = A.m,A.n,A.colptr,A.rowval
            dAv = Vector{typeof(zero(eltype(db)) * zero(eltype(dc)))}(undef, length(A.nzval))
            for j = 1:n, p = pp[j]:pp[j+1]-1
                dAv[p] = -db[ii[p]] * c[j]
            end
            dA = SparseMatrixCSC(m,n,pp,ii,dAv)
        end
        return (NO_FIELDS, dA, db)
    end
    return c, mybackslash_pullback
end
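
Written out, the loop above evaluates the usual backslash pullback only at the stored positions of A. The following is a sketch for real-valued A, where \bar{c} stands for the incoming cotangent `dc` and \mathcal{S}(A) for the set of stored index pairs; both notations are introduced here for illustration.

```latex
c = A^{-1} b, \qquad
\bar{b} = A^{-\top} \bar{c}, \qquad
\bar{A} = -\bar{b}\, c^{\top}

% Only the entries on the sparsity pattern are ever read downstream:
\bar{A}_{ij} = -\bar{b}_i\, c_j \quad \text{for } (i,j) \in \mathcal{S}(A)
```

The dense outer product has n² entries, while the restricted version computes one product per stored nonzero.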

Finally, you can optimise the SparseMatrixCSC() pullback a little bit by picking the entries in a more efficient way if dA is already a sparse matrix. I omit the details here and instead simply copy-paste the complete code.

using ChainRulesCore
using LinearAlgebra
using SparseArrays

mybackslash(A,B) = A\B

# For testing
function ChainRulesCore.rrule(::typeof(mybackslash), A, B)
    C = A\B
    function mybackslash_pullback(dC)
        dB = A'\dC
        return (NO_FIELDS, @thunk(-dB * C'), dB)
    end
    return C, mybackslash_pullback
end

function ChainRulesCore.rrule(::typeof(mybackslash), A::SparseMatrixCSC, b::AbstractVector)
    c = A\b
    function mybackslash_pullback(dc)
        db = A'\collect(dc)
        # ^ When called through Zygote, `dc` is a `FillArrays.Fill`, and
        # `A'\dc` throws an error for such `dc`. The `collect()` here is a
        # workaround for this bug.

        dA = @thunk begin
            m,n,pp,ii = A.m,A.n,A.colptr,A.rowval
            dAv = Vector{typeof(zero(eltype(db)) * zero(eltype(dc)))}(undef, length(A.nzval))
            for j = 1:n, p = pp[j]:pp[j+1]-1
                dAv[p] = -db[ii[p]] * c[j]
            end
            dA = SparseMatrixCSC(m,n,pp,ii,dAv)
        end
        return (NO_FIELDS, dA, db)
    end
    return c, mybackslash_pullback
end

function ChainRulesCore.rrule(
    ::typeof(SparseArrays.SparseMatrixCSC),
    m::Integer, n::Integer,
    pp::Vector, ii::Vector, Av::Vector
)
    A = SparseMatrixCSC(m,n,pp,ii,Av)

    function SparseMatrixCSC_pullback(dA::AbstractMatrix)
        dAv = Vector{eltype(dA)}(undef, length(Av))
        for j = 1:n, p = pp[j]:pp[j+1]-1
            dAv[p] = dA[ii[p],j]
        end
        return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
    end
    function SparseMatrixCSC_pullback(dA::SparseMatrixCSC)
        @assert getproperty.(Ref(A), (:m,:n,:colptr,:rowval)) == getproperty.(Ref(dA), (:m,:n,:colptr,:rowval))
        return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dA.nzval)
    end

    return A, SparseMatrixCSC_pullback
end


using Test
using Zygote

function test()
    @testset "SparseMatrixCSC rrule" begin
        m,n = 3,2
        pp = [1,3,4]
        ii = [1,3,2]
         Av = collect(11.0:13.0)
        dAv = collect(21.0:23.0)

        dA_sparse = SparseMatrixCSC(m,n,pp,ii,dAv)
        dA_dense = [
            dAv[1]   0
              0    dAv[3]
            dAv[2]   0
        ]

        A,pb = rrule(SparseMatrixCSC, m,n,pp,ii,Av)
        @test A == SparseMatrixCSC(m,n,pp,ii,Av)
        @test pb(dA_dense ) == (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
        @test pb(dA_sparse) == (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(), dAv)
    end

    @testset "mybackslash" begin
        m,n = 3,3
        pp = [1,3,4,6]
        ii = [1,3,2,1,2]
        Av = collect(11.0:15.0)
         b = collect(21.0:23.0)
        dc = collect(31.0:33.0)

        As = SparseMatrixCSC(m,n,pp,ii,Av)
        Ad = Matrix(As)

        cd,pbd = rrule(mybackslash, Ad, b)
        _,dAd,dbd = unthunk.(pbd(dc))
        dAd[As .== 0] .= 0
        cs,pbs = rrule(mybackslash, As, b)
        _,dAs,dbs = unthunk.(pbs(dc))

        @test cd ≈ cs
        @test dAd ≈ dAs
        @test dbd ≈ dbs
    end

    @testset "Zygote integration" begin
        # This currently only checks that things compile

        function myfun(a)
            #A = [a[1] a[2]; a[3] a[4]]
            A = SparseMatrixCSC(2,2,[1, 3, 5], [1, 2, 1, 2],a)
            b = [1.0;1.0]
            x = mybackslash(A,b)
            return sum(x)
        end

        p = [6.0,1.0,1.0,3.0]
        @show g = gradient(myfun, p)
    end
end
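
Since the Zygote integration test above only checks that things compile, a numerical cross-check may be useful. The following is a minimal sketch using only the standard library (no Zygote or ChainRulesCore). `grad_sum_solve` and `fd_grad` are hypothetical helper names, not part of the code above. It applies the same formula `dAv[p] = -db[ii[p]] * c[j]` to the loss `sum(A\b)` and compares the result with central finite differences.

```julia
using LinearAlgebra
using SparseArrays

# Hypothetical helper: gradient of sum(A \ b) with respect to the stored values
# of A, using the adjoint formula restricted to the sparsity pattern.
function grad_sum_solve(colptr, rowval, nzval, b)
    n = length(colptr) - 1
    A = SparseMatrixCSC(n, n, colptr, rowval, nzval)
    c = A \ b                     # forward solve
    db = A' \ ones(length(b))     # adjoint solve; d(sum(c))/dc is a vector of ones
    g = similar(nzval)
    for j in 1:n, p in colptr[j]:colptr[j+1]-1
        g[p] = -db[rowval[p]] * c[j]
    end
    return g
end

# Hypothetical helper: the same gradient by central finite differences
function fd_grad(colptr, rowval, nzval, b; h = 1e-6)
    n = length(colptr) - 1
    f(v) = sum(SparseMatrixCSC(n, n, colptr, rowval, v) \ b)
    g = similar(nzval)
    for k in eachindex(nzval)
        vp = copy(nzval); vp[k] += h
        vm = copy(nzval); vm[k] -= h
        g[k] = (f(vp) - f(vm)) / (2h)
    end
    return g
end

colptr = [1, 3, 5]
rowval = [1, 2, 1, 2]
p = [6.0, 1.0, 1.0, 3.0]
b = [1.0, 1.0]

g_adj = grad_sum_solve(colptr, rowval, p, b)
g_fd = fd_grad(colptr, rowval, p, b)
@assert isapprox(g_adj, g_fd; rtol = 1e-5)
```

The adjoint version needs two linear solves in total, while the finite-difference version needs two solves per stored entry, which is why it is only suitable as a check on small matrices.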

Btw, the reason some sparse-array pullbacks have not yet been implemented is that we have not yet decided on the correct way to implement them; see https://github.com/JuliaDiff/ChainRules.jl/issues/232.

Thank you! :smiley: I need some time… to digest this. I will follow up if I have question.

Fair enough. It took me about two hours to come up with this :wink: