How do you speed up the linear sparse solver in Zygote?

Thinking some more about the example I posted above: in practice, one wouldn’t write inv in the code at all. One can define a rule for the linear solve directly, so u and I should be a single function call. This means that the following line can be optimized in the rrule of the linear solve (reusing the factorization, reusing the linear solve solution, and using a lazy rank-1 matrix representation).

dK = -(K^{-1} \cdot du) \cdot (K^{-1} \cdot f)^T
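
For reference, here is a short sketch of where this expression comes from. It assumes a scalar loss L, the solution u = K^{-1} f with f held fixed, and du standing for ∂L/∂u; these names are my reading of the earlier example, not something stated there. The general result has K^{-T} in the first factor, and it reduces to the line above when K is symmetric, which I am assuming is the case here.

```latex
\begin{aligned}
K u &= f \;\Rightarrow\; \delta K\, u + K\, \delta u = 0 \;\Rightarrow\; \delta u = -K^{-1}\, \delta K\, u \\
\delta L &= du^T\, \delta u = -du^T K^{-1}\, \delta K\, u = \operatorname{tr}\!\left( -u\, du^T K^{-1}\, \delta K \right) \\
dK &= -\left(K^{-T} du\right) u^T = -\left(K^{-T} du\right) \left(K^{-1} f\right)^T
\end{aligned}
```

Both factors are vectors obtained from one linear solve each, which is why the factorization and the forward solution can be reused and the result kept as a rank-1 outer product.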

In fact, maybe all that’s needed is to make this line (LinearSolve.jl/src/adjoint.jl at main · SciML/LinearSolve.jl · GitHub) a lazy multiplication returning a rank-1 matrix instead of a dense one. @ChrisRackauckas would you be open to a PR for this?

Then one can define a rule for dxi=tr(dK^T \cdot K_i), or dot(A, B) in Julia, where A is a rank-1 matrix and B is a SparseMatrixCSC. Then we can get near-optimal performance in the above case using Zygote. I'm not sure if Enzyme rules support lazy arrays for the adjoint. This may be easier than I initially thought, without too many changes.
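
To make the dot(A, B) idea concrete, here is a minimal sketch of a lazy rank-1 matrix and a dot method that only visits the stored entries of a SparseMatrixCSC. The Rank1 type is hypothetical and made up for illustration; it is not the type used by LinearSolve.jl or ChainRules.jl, and it is restricted to real element types so no conjugation is needed.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical lazy rank-1 matrix representing u * v' without materializing it.
# Restricted to real element types (assumption) so dot needs no conjugation.
struct Rank1{T<:Real} <: AbstractMatrix{T}
    u::Vector{T}
    v::Vector{T}
end

Base.size(R::Rank1) = (length(R.u), length(R.v))
Base.getindex(R::Rank1, i::Int, j::Int) = R.u[i] * R.v[j]

# dot(R, B) = sum over (i, j) of R[i, j] * B[i, j].
# Only the stored entries of B contribute, so the cost is O(nnz(B)).
function LinearAlgebra.dot(R::Rank1, B::SparseMatrixCSC)
    size(R) == size(B) || throw(DimensionMismatch("R and B must have the same size"))
    rows = rowvals(B)
    vals = nonzeros(B)
    s = zero(promote_type(eltype(R), eltype(B)))
    for j in axes(B, 2), k in nzrange(B, j)
        s += R.u[rows[k]] * R.v[j] * vals[k]
    end
    return s
end
```

With this, dot(Rank1(x, y), K_i) never builds the dense n-by-n matrix, which is the point of keeping the adjoint lazy.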

I’d be open to a PR for this. I don’t know what would happen if the result doesn’t match the type of A, but it’s worth investigating.

Thank you! I will refer to it.

The PR above just got merged. This means that in the next version of LinearSolve.jl, if you use Zygote.jl (or any ChainRules.jl-based AD package) to get the gradient of dot(a, A \ b) wrt the matrix A, you will get a lazy outer product when A is dense.

using LinearAlgebra, LinearSolve, Zygote

function invquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(
        prob,
        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.RFLUFactorization),
    )
    return dot(a, sol.u)
end

n = 100; A = rand(n, n); b1 = rand(n); b2 = rand(n);

db1, dA, db2 = Zygote.gradient(invquad, b1, A, b2);

Base.summarysize(dA)
# 1752

Base.summarysize(A)
# 80040

If A is sparse, Zygote.pullback gives you the correct lazy outer product, but Zygote.gradient projects it onto a sparse matrix with the same sparsity structure as the input. This is arguably incorrect (Gradient wrt to a sparse matrix is mathematically wrong · Issue #1507 · FluxML/Zygote.jl · GitHub) because the gradient is mathematically defined at the structural zeros too; the fact that they are structural zeros is an implementation detail. The projection, imo, should be done by the user after the gradient call if needed. Anyway, if you use pullback directly, you are safe.

function sinvquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return dot(a, sol.u)
end

using SparseArrays  # provides sparse

sA = sparse(A)
dsA = Zygote.pullback(sinvquad, b1, sA, b2)[2](1.0)[2]

Base.summarysize(dsA)
# 1752

Base.summarysize(sA)
# 160968
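
If you do want the gradient restricted to the sparsity pattern afterwards, the projection can be done by hand. Below is a minimal sketch; project_to_pattern is a hypothetical helper, not part of Zygote.jl or LinearSolve.jl, and it assumes the lazy gradient is an AbstractMatrix that supports scalar indexing.

```julia
using SparseArrays

# Hypothetical helper: evaluate G only at the stored positions of S and
# return a sparse matrix with the same pattern as S.
function project_to_pattern(G::AbstractMatrix, S::SparseMatrixCSC)
    size(G) == size(S) || throw(DimensionMismatch("G and S must have the same size"))
    rows, cols, _ = findnz(S)
    vals = [G[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(S)...)
end
```

Under that assumption, the call would look like project_to_pattern(dsA, sA), and it only touches nnz(sA) entries of the lazy outer product.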

I think the above PR was the first step towards preserving structure in rules and using lazy representations where possible. The next step is Make the rrule for 3-arg dot lazy · Issue #788 · JuliaDiff/ChainRules.jl · GitHub.
