Hi there, I am trying to differentiate through sparse matrix construction; here is the MWE:
```julia
using LinearAlgebra
using SparseArrays
using ChainRulesCore
using Zygote

Zygote.@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

function test1(a)
    A = sparse([1, 1, 2, 2], [1, 2, 1, 2], a, 2, 2)
    return sum(A)
end

a = [1.0, 2.0, 3.0, 4.0]
gradient(test1, a)
```
Then I got this error:
```
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#399#400")(::Nothing) at /root/.julia/packages/Zygote/CgsVi/src/lib/array.jl:58
 [3] (::Zygote.var"#2265#back#401"{Zygote.var"#399#400"})(::Nothing) at /root/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] sparse! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:862 [inlined]
 [5] (::typeof(∂(sparse!)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [6] sparse at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:703 [inlined]
 [7] (::typeof(∂(sparse)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [8] sparse at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:892 [inlined]
 [9] (::typeof(∂(sparse)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [10] test1 at /root/codes/test_zygote/test_sparse.jl:11 [inlined]
 [11] (::typeof(∂(test1)))(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#41#42"{typeof(∂(test1))})(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:41
 [13] gradient(::Function, ::Array{Float64,1}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:59
 [14] top-level scope at /root/codes/test_zygote/test_sparse.jl:16
in expression starting at /root/codes/test_zygote/test_sparse.jl:16
```
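The trace shows where the error is raised: frame [4] is `sparse!` (reached from the two `sparse` methods in frames [6] and [8]), and frames [1] to [3] are Zygote's pullback for an in-place array write, which throws instead of returning a gradient. A minimal sketch of the same failure mode, using a hypothetical function `fill_and_sum` that writes into a preallocated array in roughly the way `sparse!` fills its buffers:

```julia
using Zygote

# Hypothetical example (not from SparseArrays): build a result by writing
# into a preallocated array inside the function being differentiated.
function fill_and_sum(x)
    y = zeros(2)
    y[1] = x[1]   # in-place write (setindex!): Zygote does not support this
    y[2] = x[2]
    return sum(y)
end

try
    gradient(fill_and_sum, [1.0, 2.0])
catch err
    # Expected: an error whose message mentions "Mutating arrays is not supported"
    println(sprint(showerror, err))
end
```

In user code this pattern can be rewritten with `Zygote.Buffer`, but that is not an option for the stdlib's own `sparse!`.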
Then I changed `sparse` to `SparseMatrixCSC`, and it works:
```julia
using LinearAlgebra
using SparseArrays
using ChainRulesCore
using Zygote

Zygote.@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

function test2(a)
    A = SparseMatrixCSC(2, 2, [1, 3, 5], [1, 2, 1, 2], a)
    return sum(A)
end

a = [1.0, 2.0, 3.0, 4.0]
gradient(test2, a)
```
Output:

```
([1.0, 1.0, 1.0, 1.0],)
```
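Both calls do return a `SparseMatrixCSC`, but they read their inputs in different layouts: `sparse` takes (row, column, value) triplets, while the raw constructor takes the CSC fields `colptr`, `rowval` and `nzval`, with `nzval` stored column by column. A quick check using only the SparseArrays stdlib shows that, with the same `a`, the matrices built in `test1` and `test2` are transposes of each other. `sum` is unaffected, so the difference does not show up in the output of `test2`, but it would for a non-symmetric downstream computation:

```julia
using SparseArrays

a = [1.0, 2.0, 3.0, 4.0]

# test1's construction: (row, column, value) triplets.
A1 = sparse([1, 1, 2, 2], [1, 2, 1, 2], a, 2, 2)

# test2's construction: raw CSC fields (colptr, rowval, nzval), column-major.
A2 = SparseMatrixCSC(2, 2, [1, 3, 5], [1, 2, 1, 2], a)

A1 isa SparseMatrixCSC{Float64,Int}  # true: same type either way
Matrix(A1)  # [1.0 2.0; 3.0 4.0]
Matrix(A2)  # [1.0 3.0; 2.0 4.0], the transpose of A1
```

To reproduce `A1` with the raw constructor, `nzval` would have to be reordered column-major, i.e. `[1.0, 3.0, 2.0, 4.0]`.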
As far as I understand, calling `sparse` also creates a `SparseMatrixCSC`, so my questions are:
- I know that when the same (row, column) pair appears more than once in the input index arrays, `sparse` adds up the values of the duplicated entries. Is this the root cause of “Mutating arrays is not supported” when calling `sparse`?
- How can I differentiate `sparse` directly? I ask because I know that calling `SparseMatrixCSC` directly, rather than the exposed API `sparse`, is discouraged.
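On the first question, note that the MWE contains no duplicate (row, column) pairs, so duplicate summation by itself cannot be what triggers the error; as far as the trace shows, it comes from the in-place writes inside `sparse!`. On the second question, one common approach is to give `sparse` its own reverse rule so that Zygote never traces into `sparse!`. Below is a sketch of a hypothetical `ChainRulesCore.rrule` for the 5-argument method `sparse(I, J, V, m, n)`, assuming that no rule for this method already exists in the installed ChainRules and that the incoming cotangent supports `Δ[i, j]` indexing (dense, `Fill` and `SparseMatrixCSC` cotangents do; a structural `Tangent` would not). Duplicates matter only for the math of the rule: with the default `combine = +`, each `V[k]` enters `A[I[k], J[k]]` with slope 1, so its cotangent is `Δ[I[k], J[k]]` whether or not the pair is repeated.

```julia
using SparseArrays, ChainRulesCore, Zygote

# Hypothetical rule for sparse(I, J, V, m, n) with the default combine = +.
# Assumption: ChainRules does not already define a rule for this method.
function ChainRulesCore.rrule(::typeof(sparse),
                              I::AbstractVector{<:Integer},
                              J::AbstractVector{<:Integer},
                              V::AbstractVector{<:Number},
                              m::Integer, n::Integer)
    A = sparse(I, J, V, m, n)  # forward pass: the ordinary (mutating) constructor
    function sparse_pullback(ΔA)
        Δ = unthunk(ΔA)
        # No gradient flowing back: V gets a zero cotangent.
        Δ isa AbstractZero &&
            return NoTangent(), NoTangent(), NoTangent(), ZeroTangent(), NoTangent(), NoTangent()
        # Duplicates are combined with +, so dA[I[k], J[k]] / dV[k] == 1 and the
        # cotangent of V[k] is the cotangent entry at (I[k], J[k]).
        # Assumes Δ supports Δ[i, j] (dense, Fill or sparse cotangents).
        ΔV = [Δ[i, j] for (i, j) in zip(I, J)]
        return NoTangent(), NoTangent(), NoTangent(), ΔV, NoTangent(), NoTangent()
    end
    return A, sparse_pullback
end

test1(a) = sum(sparse([1, 1, 2, 2], [1, 2, 1, 2], a, 2, 2))
gradient(test1, [1.0, 2.0, 3.0, 4.0])
```

With the rule in place, the same gather works when pairs repeat, e.g. `sparse([1, 1], [1, 1], v, 1, 1)`. The 6-argument form with a non-additive `combine` (such as `max`) is not covered by this sketch: there the slope is no longer 1 and a different pullback would be needed.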
Thx for any reply.