The remaining problem is that Zygote doesn’t know how to pull back through `A = sparse(i,j,v)`, and the solution is to provide this pullback. I believe the pullback for `SparseMatrixCSC()` is easier to write than the one for `sparse()`, so what I did was to replace
# Triplet (row, column, value) construction: A = [a[1] a[2]; a[3] a[4]]
A = sparse([1, 1, 2, 2], [1, 2, 1, 2], a)
with
# Note: the raw CSC constructor reads `a` in column-major order, so this builds
# [a[1] a[3]; a[2] a[4]], i.e. the transpose of the `sparse(...)` matrix above.
A = SparseMatrixCSC(2,2,[1, 3, 5], [1, 2, 1, 2],a)
and define
# Reverse rule for the raw CSC constructor `SparseMatrixCSC(m, n, pp, ii, Av)`.
#
# Only the stored values `Av` are differentiable. The cotangent for `Av` is
# obtained by reading the matrix cotangent `dA` at every stored position of the
# constructed matrix; the dimensions and index vectors get `DoesNotExist()`.
function ChainRulesCore.rrule(
    ::typeof(SparseArrays.SparseMatrixCSC),
    m::Integer, n::Integer,
    pp::Vector, ii::Vector, Av::Vector
)
    A = SparseMatrixCSC(m, n, pp, ii, Av)

    function SparseMatrixCSC_pullback(dA)
        dAv = Vector{eltype(dA)}(undef, length(Av))
        for col in 1:n
            # `pp[col]:(pp[col + 1] - 1)` are the storage slots of column `col`
            for slot in pp[col]:(pp[col + 1] - 1)
                dAv[slot] = dA[ii[slot], col]
            end
        end
        nondiff = DoesNotExist()
        return (NO_FIELDS, nondiff, nondiff, nondiff, nondiff, dAv)
    end

    return A, SparseMatrixCSC_pullback
end
This should work, but it isn’t quite as efficient as it could be. The problem is that the pullback for `mybackslash` computes a dense `dA` even though the pullback for `SparseMatrixCSC()` only looks at a subset of the entries of `dA`, so to do better we should compute only those entries of `dA` which are actually needed. This can be done as follows.
# Reverse rule for `mybackslash(A, b)` with a sparse `A` and a vector `b`.
#
# Returns the solution `c = A \ b` and a pullback mapping the cotangent `dc` to
# `(NO_FIELDS, dA, db)` where
#   * `db = A' \ dc`, and
#   * `dA` is a thunk producing `-db * c'` restricted to the stored entries of
#     `A`, as a `SparseMatrixCSC` that shares `A`'s `colptr` and `rowval`.
function ChainRulesCore.rrule(::typeof(mybackslash), A::SparseMatrixCSC, b::AbstractVector)
    c = A \ b
    function mybackslash_pullback(dc)
        # When called through Zygote, `dc` is a `FillArrays.Fill`, and
        # `A' \ dc` throws an error for such `dc`. The `collect()` here is a
        # workaround for this bug.
        db = A' \ collect(dc)
        # The thunk body must not assign to `dA`: doing so would capture and
        # reassign the enclosing variable, forcing it into a `Core.Box` and
        # making the returned tuple's type uninferrable.
        dA = @thunk begin
            m, n, pp, ii = A.m, A.n, A.colptr, A.rowval
            T = typeof(zero(eltype(db)) * zero(eltype(c)))
            dAv = Vector{T}(undef, length(A.nzval))
            for j in 1:n, p in pp[j]:(pp[j + 1] - 1)
                # Entry (ii[p], j) of `-db * c'`; `conj` matters for complex data
                # and is a no-op for real data.
                dAv[p] = -db[ii[p]] * conj(c[j])
            end
            SparseMatrixCSC(m, n, pp, ii, dAv)
        end
        return (NO_FIELDS, dA, db)
    end
    return c, mybackslash_pullback
end
Finally, you can optimise the `SparseMatrixCSC()` pullback a little bit by picking out the entries more efficiently if `dA` is already a sparse matrix. I omit the details here and instead simply copy-paste the complete code.
using ChainRulesCore
using LinearAlgebra
using SparseArrays
# Thin wrapper around `\`: returns `A \ B`.
function mybackslash(A, B)
    return A \ B
end
# For testing
# Generic reverse rule for `mybackslash(A, B)`.
#
# Returns `A \ B` together with a pullback that maps the output cotangent `dC`
# to `(NO_FIELDS, dA, dB)`, where `dB = A' \ dC` and `dA` is a lazily evaluated
# (thunked) `-dB * C'`.
function ChainRulesCore.rrule(::typeof(mybackslash), A, B)
    solution = A \ B

    function mybackslash_pullback(dC)
        rhs_cotangent = A' \ dC
        lhs_cotangent = @thunk(-rhs_cotangent * solution')
        return (NO_FIELDS, lhs_cotangent, rhs_cotangent)
    end

    return solution, mybackslash_pullback
end
# Reverse rule for `mybackslash(A, b)` with a sparse `A` and a vector `b`.
#
# Returns the solution `c = A \ b` and a pullback mapping the cotangent `dc` to
# `(NO_FIELDS, dA, db)` where
#   * `db = A' \ dc`, and
#   * `dA` is a thunk producing `-db * c'` restricted to the stored entries of
#     `A`, as a `SparseMatrixCSC` that shares `A`'s `colptr` and `rowval`.
function ChainRulesCore.rrule(::typeof(mybackslash), A::SparseMatrixCSC, b::AbstractVector)
    c = A \ b
    function mybackslash_pullback(dc)
        # When called through Zygote, `dc` is a `FillArrays.Fill`, and
        # `A' \ dc` throws an error for such `dc`. The `collect()` here is a
        # workaround for this bug.
        db = A' \ collect(dc)
        # The thunk body must not assign to `dA`: doing so would capture and
        # reassign the enclosing variable, forcing it into a `Core.Box` and
        # making the returned tuple's type uninferrable.
        dA = @thunk begin
            m, n, pp, ii = A.m, A.n, A.colptr, A.rowval
            T = typeof(zero(eltype(db)) * zero(eltype(c)))
            dAv = Vector{T}(undef, length(A.nzval))
            for j in 1:n, p in pp[j]:(pp[j + 1] - 1)
                # Entry (ii[p], j) of `-db * c'`; `conj` matters for complex data
                # and is a no-op for real data.
                dAv[p] = -db[ii[p]] * conj(c[j])
            end
            SparseMatrixCSC(m, n, pp, ii, dAv)
        end
        return (NO_FIELDS, dA, db)
    end
    return c, mybackslash_pullback
end
# Reverse rule for the raw CSC constructor `SparseMatrixCSC(m, n, pp, ii, Av)`.
#
# Only the stored values `Av` are differentiable; the dimensions and index
# vectors get `DoesNotExist()`. The pullback accepts any `AbstractMatrix`
# cotangent `dA` and returns its entries at the stored positions of `A`.
#
# When `dA` is itself a `SparseMatrixCSC` with exactly the same sparsity
# pattern as `A`, its `nzval` vector already is that result and is returned
# directly (no copy). A sparse `dA` with a different pattern is handled by the
# generic entry-by-entry gather instead of being rejected.
function ChainRulesCore.rrule(
    ::typeof(SparseArrays.SparseMatrixCSC),
    m::Integer, n::Integer,
    pp::Vector, ii::Vector, Av::Vector
)
    A = SparseMatrixCSC(m, n, pp, ii, Av)

    # Package the value cotangent with the non-differentiable placeholders.
    function tangents(dAv)
        return (
            NO_FIELDS,
            DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(),
            dAv,
        )
    end

    # Pick out the entries in `dA` corresponding to stored entries in `A`.
    function gather(dA)
        dAv = Vector{eltype(dA)}(undef, length(Av))
        for j in 1:n, p in pp[j]:(pp[j + 1] - 1)
            dAv[p] = dA[ii[p], j]
        end
        return dAv
    end

    SparseMatrixCSC_pullback(dA::AbstractMatrix) = tangents(gather(dA))

    function SparseMatrixCSC_pullback(dA::SparseMatrixCSC)
        samepattern =
            size(dA) == size(A) &&
            dA.colptr == pp &&
            dA.rowval == ii &&
            length(dA.nzval) == length(Av)
        # Both branches yield a `Vector{eltype(dA)}`, so this stays type-stable.
        return tangents(samepattern ? dA.nzval : gather(dA))
    end

    return A, SparseMatrixCSC_pullback
end
using Test
using Zygote
# Run the test suite for the custom `rrule`s defined in this file:
#   1. the `SparseMatrixCSC` constructor rule, with dense and sparse cotangents;
#   2. the sparse `mybackslash` rule, checked against the generic rule;
#   3. end-to-end differentiation through Zygote, checked against the
#      hand-computed gradient.
function test()
    @testset "SparseMatrixCSC rrule" begin
        m, n = 3, 2
        pp = [1, 3, 4]
        ii = [1, 3, 2]
        Av = collect(11.0:13.0)
        dAv = collect(21.0:23.0)
        dA_sparse = SparseMatrixCSC(m, n, pp, ii, dAv)
        dA_dense = [
            dAv[1] 0
            0      dAv[3]
            dAv[2] 0
        ]

        A, pb = rrule(SparseMatrixCSC, m, n, pp, ii, Av)
        expected = (
            NO_FIELDS,
            DoesNotExist(), DoesNotExist(), DoesNotExist(), DoesNotExist(),
            dAv,
        )
        @test A == SparseMatrixCSC(m, n, pp, ii, Av)
        @test pb(dA_dense) == expected
        @test pb(dA_sparse) == expected
    end

    @testset "mybackslash" begin
        m, n = 3, 3
        pp = [1, 3, 4, 6]
        ii = [1, 3, 2, 1, 2]
        Av = collect(11.0:15.0)
        b = collect(21.0:23.0)
        dc = collect(31.0:33.0)
        As = SparseMatrixCSC(m, n, pp, ii, Av)
        Ad = Matrix(As)

        # Reference: generic rule on the dense matrix, masked to the pattern of `As`
        # because the sparse rule only produces entries at stored positions.
        c_dense, pb_dense = rrule(mybackslash, Ad, b)
        _, dAd, dbd = unthunk.(pb_dense(dc))
        dAd[As .== 0] .= 0

        c_sparse, pb_sparse = rrule(mybackslash, As, b)
        _, dAs, dbs = unthunk.(pb_sparse(dc))

        @test c_dense ≈ c_sparse
        @test dAd ≈ dAs
        @test dbd ≈ dbs
    end

    @testset "Zygote integration" begin
        function myfun(a)
            # CSC storage is column-major, so A = [a[1] a[3]; a[2] a[4]]
            A = SparseMatrixCSC(2, 2, [1, 3, 5], [1, 2, 1, 2], a)
            b = [1.0; 1.0]
            x = mybackslash(A, b)
            return sum(x)
        end
        p = [6.0, 1.0, 1.0, 3.0]
        @show g = gradient(myfun, p)
        # For A = [6 1; 1 3] and b = [1, 1]: x = [2, 5] / 17 and
        # d(sum(x))/dA = -(A' \ ones(2)) * x' = -[4 10; 10 25] / 289,
        # read off in column-major order to match the layout of `a`.
        @test g[1] ≈ -[4.0, 10.0, 10.0, 25.0] ./ 289
    end
end