How do you speed up the linear sparse solver in Zygote?

Note: I just checked, and all of this is a non-issue with Enzyme, which differentiates through the sparse constructors fine:

using Enzyme, SparseArrays
function f1(A)
    sum(SparseMatrixCSC(A))
end

A = rand(2,2)
dA = zeros(2,2)
Enzyme.autodiff(Reverse, f1, Duplicated(A, dA))
@show dA
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0

I = [1, 4, 3, 5]; J = [4, 7, 18, 9]; V = [1.0, 2, -5, 3];
S = sparse(I,J,V)
function f2(I,J,V)
    sum(sparse(I,J,V))
end

dV = zeros(4)
# I and J are integer index vectors, so mark them as constants explicitly
Enzyme.autodiff(Reverse, f2, Const(I), Const(J), Duplicated(V, dV))
@show dV

4-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
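
As a sanity check on the all-ones gradient above, here is a minimal sketch that compares it against central finite differences, without using Enzyme at all. The step size `h` and the name `fd` are my own choices for illustration, not anything from Enzyme or SparseArrays. Since `sum(sparse(I,J,V))` is linear in `V`, each entry should come out very close to 1.0.

```julia
using SparseArrays

I = [1, 4, 3, 5]; J = [4, 7, 18, 9]; V = [1.0, 2, -5, 3]
f2(I, J, V) = sum(sparse(I, J, V))

# Central finite difference with respect to each entry of V
h = 1e-6  # step size, chosen arbitrarily for this sketch
fd = map(eachindex(V)) do k
    e = zeros(length(V))
    e[k] = h  # perturb only the k-th value
    (f2(I, J, V .+ e) - f2(I, J, V .- e)) / (2h)
end
@show fd
```

If the entries of `fd` agree with `dV` up to floating-point noise, the reverse-mode result is consistent with the numerical derivative.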