How do you speed up the linear sparse solver in Zygote?

The PR above just got merged, which means that in the next version of LinearSolve.jl, if you use Zygote.jl (or any ChainRules.jl-based AD package) to get the gradient of dot(a, A \ b) with respect to the matrix A, you will get a lazy outer product instead of a fully materialized matrix. This is true when A is dense.
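
For context on why a lazy representation is even possible here: writing the scalar as a' * inv(A) * b and using the standard identity for the derivative of a matrix inverse, the gradient with respect to A is rank one (a sketch, not from the post above):

```latex
\frac{\partial}{\partial A}\left( a^\top A^{-1} b \right)
  = -\,A^{-\top} a \, b^\top A^{-\top}
  = -\,\left(A^{-\top} a\right)\left(A^{-1} b\right)^\top
```

Since it is an outer product of two vectors, storing those two vectors lazily costs O(n) memory instead of the O(n^2) needed to materialize the full matrix.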

using LinearAlgebra, LinearSolve, Zygote

function invquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(
        prob,
        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.RFLUFactorization),
    )
    return dot(a, sol.u)
end

n = 100; A = rand(n, n); b1 = rand(n); b2 = rand(n);

db1, dA, db2 = Zygote.gradient(invquad, b1, A, b2);

Base.summarysize(dA)
# 1752

Base.summarysize(A)
# 80040
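As a sanity check (my example, not from the original post), the rank-1 formula -(A'⁻¹ a)(A⁻¹ b)' can be verified against Zygote on the plain backslash, with no LinearSolve involved:

```julia
using LinearAlgebra, Zygote

n = 5
A = rand(n, n) + n * I   # shift by n*I so A is well-conditioned
a = rand(n); b = rand(n)

# Zygote's gradient of dot(a, A \ b) with respect to A
dA = Zygote.gradient(M -> dot(a, M \ b), A)[1]

# Analytic rank-1 formula: -(A'⁻¹ a)(A⁻¹ b)'
@assert dA ≈ -(A' \ a) * (A \ b)'
```

The lazy representation from the PR only needs to store those two length-n vectors, which is why summarysize(dA) above is so small compared to summarysize(A).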

If A is sparse, Zygote.pullback gives you the correct lazy outer product, but Zygote.gradient projects that onto a sparse matrix with the same sparsity structure as the input. This is arguably incorrect (Gradient wrt to a sparse matrix is mathematically wrong · Issue #1507 · FluxML/Zygote.jl · GitHub) because the gradient is mathematically well-defined, and generally nonzero, at the structural zeros; the fact that they are structural zeros is an implementation detail. The projection imo should be done by the user after the gradient call if needed. Anyway, if you use pullback directly, you are safe.

function sinvquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return dot(a, sol.u)
end

using SparseArrays
sA = sparse(A)
dsA = Zygote.pullback(sinvquad, b1, sA, b2)[2](1.0)[2]

Base.summarysize(dsA)
# 1752

Base.summarysize(sA)
# 160968
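
To make the gradient-vs-pullback distinction concrete, here is a sketch using the plain sparse backslash (my example, not from the post; with plain \ the unprojected cotangent comes back as a dense matrix rather than a lazy outer product, but the projection behavior is the same):

```julia
using LinearAlgebra, SparseArrays, Zygote

n = 100
sA = sprand(n, n, 0.05) + n * I   # sparse and well-conditioned
a = rand(n); b = rand(n)

f(M) = dot(a, M \ b)

# Zygote.gradient projects the cotangent onto sA's sparsity pattern,
# zeroing the (mathematically nonzero) entries at structural zeros.
dA_grad = Zygote.gradient(f, sA)[1]
@assert dA_grad isa SparseMatrixCSC

# Zygote.pullback returns the unprojected cotangent.
dA_pb = Zygote.pullback(f, sA)[2](1.0)[1]
@assert !(dA_pb isa SparseMatrixCSC)
```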

I think the above PR was the first step towards preserving structure in rules and using lazy representations where possible. The next step is Make the rrule for 3-arg dot lazy · Issue #788 · JuliaDiff/ChainRules.jl · GitHub.
