Hi there, I am trying to differentiate through the construction of a sparse matrix. Here is the MWE:
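```julia
using LinearAlgebra
using SparseArrays
using ChainRulesCore
using Zygote

# Custom adjoint for converting a dense array into a SparseMatrixCSC{T,N}
# (T: element type, N: index type); the pullback densifies the cotangent.
Zygote.@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

function test1(a)
    # 2×2 matrix built from COO triplets: row indices, column indices, values
    A = sparse([1, 1, 2, 2], [1, 2, 1, 2], a, 2, 2)
    return sum(A)
end

a = [1.0, 2.0, 3.0, 4.0]
gradient(test1, a)
```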
Then I got this error:
```
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#399#400")(::Nothing) at /root/.julia/packages/Zygote/CgsVi/src/lib/array.jl:58
[3] (::Zygote.var"#2265#back#401"{Zygote.var"#399#400"})(::Nothing) at /root/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] sparse! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:862 [inlined]
[5] (::typeof(∂(sparse!)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
[6] sparse at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:703 [inlined]
[7] (::typeof(∂(sparse)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
[8] sparse at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/sparsematrix.jl:892 [inlined]
[9] (::typeof(∂(sparse)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
[10] test1 at /root/codes/test_zygote/test_sparse.jl:11 [inlined]
[11] (::typeof(∂(test1)))(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
[12] (::Zygote.var"#41#42"{typeof(∂(test1))})(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:41
[13] gradient(::Function, ::Array{Float64,1}) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:59
[14] top-level scope at /root/codes/test_zygote/test_sparse.jl:16
in expression starting at /root/codes/test_zygote/test_sparse.jl:16
```
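The stack trace shows the error coming from Zygote's rule for array mutation (`lib/array.jl`), reached from `sparse!` (`sparsematrix.jl:862`), which `sparse` calls to fill its CSC arrays in place. As far as I can tell, that in-place filling happens whether or not there are duplicate indices; duplicates only change what the derivative looks like. Because `sparse(I, J, V, m, n)` adds each `V[k]` into entry `(I[k], J[k])`, the cotangent for `V[k]` is just the incoming cotangent read at that position, so duplicated coordinates receive the same value. A minimal, stdlib-only sketch of that pullback (`sparse_pullback_V` is a hypothetical helper, not part of SparseArrays):

```julia
using SparseArrays

# COO triplets in which the coordinate (1, 1) appears twice (k = 1 and k = 4).
rows = [1, 1, 2, 1]
cols = [1, 2, 2, 1]
vals = [1.0, 2.0, 3.0, 4.0]

# sparse combines duplicates with +, so A[1, 1] == 1.0 + 4.0 and only 3 entries are stored.
A = sparse(rows, cols, vals, 2, 2)

# Hypothetical hand-written pullback of vals -> sparse(rows, cols, vals, m, n).
# A[i, j] is the sum of vals[k] over all k with (rows[k], cols[k]) == (i, j),
# so ∂A[i, j]/∂vals[k] is 1 at (rows[k], cols[k]) and 0 elsewhere; the
# cotangent for vals[k] is therefore the incoming cotangent at that entry.
sparse_pullback_V(ΔA, rows, cols) = [ΔA[rows[k], cols[k]] for k in eachindex(rows, cols)]

# For sum(A) the incoming cotangent is all ones, so every value gets gradient 1.
ΔV = sparse_pullback_V(ones(2, 2), rows, cols)

# With a non-uniform cotangent, both copies of (1, 1) receive ΔA[1, 1].
ΔV2 = sparse_pullback_V([10.0 20.0; 30.0 40.0], rows, cols)
```

No mutation is needed on the backward pass: it is a plain gather, which is why a custom rule can sidestep `sparse!` entirely.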
When I replaced `sparse` with a direct call to the `SparseMatrixCSC` constructor, it works:
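```julia
using LinearAlgebra
using SparseArrays
using ChainRulesCore
using Zygote

Zygote.@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
    SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

function test2(a)
    # Same four values passed directly in CSC form: colptr = [1, 3, 5],
    # rowval = [1, 2, 1, 2], nzval = a (column-major, so a[2] and a[3] land in
    # A[2, 1] and A[1, 2], the reverse of test1; the sum is unchanged)
    A = SparseMatrixCSC(2, 2, [1, 3, 5], [1, 2, 1, 2], a)
    return sum(A)
end

a = [1.0, 2.0, 3.0, 4.0]
gradient(test2, a)
```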
Output:
```
([1.0, 1.0, 1.0, 1.0],)
```
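If the index vectors are fixed and only the values are differentiated, one possible workaround builds on the `SparseMatrixCSC` route: compute the CSC structure once, outside the differentiated code, together with a constant 0/1 matrix `S` that maps the COO values to the stored values (summing duplicates the way `sparse` does). Only `S * V` and the constructor then stay on the AD path. This is a sketch under those assumptions; `csc_structure` and `build_sparse` are hypothetical helper names:

```julia
using SparseArrays

# Hypothetical helper, run outside of AD: CSC structure for fixed COO indices,
# plus a 0/1 matrix S with nzval = S * V (duplicates end up summed).
function csc_structure(rows, cols, m, n)
    k = length(rows)
    pattern = sparse(rows, cols, ones(k), m, n)  # duplicates collapse to one stored entry
    colptr, rowval = pattern.colptr, pattern.rowval
    # Position of each COO entry inside the CSC storage arrays.
    pos = map(1:k) do t
        r = colptr[cols[t]]:(colptr[cols[t] + 1] - 1)
        r[findfirst(==(rows[t]), view(rowval, r))]
    end
    S = sparse(pos, 1:k, ones(k), length(rowval), k)
    return colptr, rowval, S
end

# Hypothetical helper for the differentiated part: a linear map plus the
# CSC constructor (colptr and rowval are shared, not copied).
build_sparse(colptr, rowval, S, V, m, n) = SparseMatrixCSC(m, n, colptr, rowval, S * V)

# Usage: (1, 1) appears twice, as sparse(rows, cols, V, 2, 2) would allow.
rows, cols = [1, 1, 2, 1], [1, 2, 2, 1]
colptr, rowval, S = csc_structure(rows, cols, 2, 2)
A = build_sparse(colptr, rowval, S, [1.0, 2.0, 3.0, 4.0], 2, 2)
```

Since every matrix built this way shares `colptr` and `rowval`, it should not be modified structurally (for example, by inserting new nonzeros).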
As far as I understand, `sparse` also creates a `SparseMatrixCSC`, so my questions are:

- I know that when the same (row, column) pair appears more than once in the input index arrays, `sparse` adds up the values of the duplicated entries. Is this the root cause of "Mutating arrays is not supported" when calling `sparse`?
- How can I differentiate through `sparse` directly? I ask because I know that calling the `SparseMatrixCSC` constructor instead of the exported API `sparse` is discouraged.
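For the second question, one common approach is a custom reverse rule for `sparse`. The sketch below assumes a current ChainRulesCore (older releases used different names for `NoTangent`) and a Zygote version that picks up ChainRules `rrule`s. It covers only the 5-argument method `sparse(I, J, V, m, n)`, treats only `V` as differentiable, and is an illustration rather than a rule shipped by SparseArrays or ChainRules:

```julia
using SparseArrays
using ChainRulesCore

# Sketch of a reverse rule for sparse(I, J, V, m, n); only V is differentiable,
# the index vectors and sizes get NoTangent().
function ChainRulesCore.rrule(::typeof(sparse),
                              I::AbstractVector{<:Integer},
                              J::AbstractVector{<:Integer},
                              V::AbstractVector{<:Number},
                              m::Integer, n::Integer)
    A = sparse(I, J, V, m, n)
    function sparse_IJV_pullback(ΔA)
        Δ = unthunk(ΔA)
        # Each V[k] is added into A[I[k], J[k]] (duplicates are summed),
        # so its cotangent is the cotangent of that entry.
        ΔV = [Δ[I[k], J[k]] for k in eachindex(I, J, V)]
        return NoTangent(), NoTangent(), NoTangent(), ΔV, NoTangent(), NoTangent()
    end
    return A, sparse_IJV_pullback
end
```

With this method defined before the first `gradient` call, `gradient(test1, a)` would be expected to return `([1.0, 1.0, 1.0, 1.0],)`, like `test2`, because Zygote should use the rule instead of tracing into `sparse!`.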
Thx for any reply.