How do you speed up the sparse linear solver in Zygote?

Thinking some more about the example I posted above: in practice, one wouldn't write `inv` in the code at all. One can instead define a rule for the linear solve directly, so u and I should be a single function call. This means the following line can be optimized in the `rrule` of the linear solve (reusing the factorization, reusing the forward solution, and representing the result lazily as a rank-1 matrix):

$$dK = -\left(K^{-T}\, du\right)\left(K^{-1} f\right)^T$$

which for a symmetric $K$ reduces to $-\left(K^{-1} du\right)\left(K^{-1} f\right)^T$.
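A minimal sketch of what such a rule could compute, written with only the `LinearAlgebra` and `SparseArrays` standard libraries rather than as an actual ChainRules `rrule` or LinearSolve.jl code. The helper name `solve_with_pullback` is hypothetical. It factorizes `K` once, reuses that factorization for the adjoint solve $K^T \lambda = du$, and returns `dK` as its two rank-1 factors instead of a dense matrix.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper (not a LinearSolve.jl or Zygote API): forward solve
# u = K \ f that keeps the factorization around for the pullback.
function solve_with_pullback(K::SparseMatrixCSC, f::AbstractVector)
    F = lu(K)              # factorize once (UMFPACK for sparse K)
    u = F \ f              # forward solution, reused in the pullback
    function pullback(du::AbstractVector)
        λ = F' \ du        # adjoint solve K^T λ = du, reusing the same factorization
        # dK = -λ * u' is rank-1, so return its two factors, not a dense matrix.
        # The cotangent of f is λ itself.
        return (-λ, u), λ
    end
    return u, pullback
end

# Usage on a small non-symmetric matrix, so transposes actually matter
K = sparse([4.0 1.0 0.0; 2.0 3.0 1.0; 0.0 1.0 2.0])
f = [1.0, 2.0, 3.0]
du = [0.5, -1.0, 2.0]
u, back = solve_with_pullback(K, f)
(a, b), df = back(du)          # dK == a * b'
```

Only `a * b'` would ever need to be materialized, and only if some downstream code really wants a dense `dK`.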

In fact, maybe all that's needed is to make this line in LinearSolve.jl/src/adjoint.jl (SciML/LinearSolve.jl on GitHub) a lazy multiplication that returns a rank-1 matrix instead of a dense one. @ChrisRackauckas, would you be open to a PR for this?
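One possible shape for such a lazy rank-1 matrix, sketched as a hypothetical `RankOneMatrix` type (not an existing LinearSolve.jl or ChainRules type). It stores only the vectors `a` and `b` of `a * b'`, so memory is O(m + n), and a matrix-vector product costs O(m + n) instead of O(mn).

```julia
using LinearAlgebra

# Hypothetical lazy rank-1 matrix representing a * b' (illustrative only).
struct RankOneMatrix{T} <: AbstractMatrix{T}
    a::Vector{T}
    b::Vector{T}
end

Base.size(A::RankOneMatrix) = (length(A.a), length(A.b))

# Entry (i, j) of a * b' is a[i] * conj(b[j]); computed on demand, never stored.
Base.getindex(A::RankOneMatrix, i::Int, j::Int) = A.a[i] * conj(A.b[j])

# (a * b') * x == a * (b' * x): one dot product and one scaling.
Base.:*(A::RankOneMatrix, x::AbstractVector) = A.a * dot(A.b, x)

# Usage
a = [1.0, 2.0, 3.0]
b = [4.0, 5.0]
A = RankOneMatrix(a, b)
y = A * [1.0, -1.0]
```

Because it is an `AbstractMatrix`, generic code still works through `getindex`, just at dense cost; the payoff comes from specialized methods such as the `*` above, or a `dot` against a sparse matrix.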

Then one can define a rule for $dx_i = \operatorname{tr}(dK^T K_i)$, i.e. `dot(A, B)` in Julia, where `A` is a rank-1 matrix and `B` is a `SparseMatrixCSC`. With that, we can get near-optimal performance in the above case using Zygote. I'm not sure whether Enzyme rules support lazy arrays for the adjoint. This may be easier than I initially thought, without too many changes.
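Why such a `dot` rule is cheap: for $dK = a b^T$, $\operatorname{tr}(dK^T K_i) = a^T K_i b$, so only the stored nonzeros of $K_i$ need to be visited and no dense $n \times n$ matrix is ever formed. A hedged sketch with a hypothetical `rank_one_dot` helper that takes the two factors directly:

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: dot(a * b', Ki) without materializing a * b'.
# dot(A, B) = sum conj(A[j,k]) * B[j,k], and conj((a*b')[j,k]) = conj(a[j]) * b[k],
# so the sum only runs over the stored entries of Ki: O(nnz(Ki)) work.
function rank_one_dot(a::AbstractVector, b::AbstractVector, Ki::SparseMatrixCSC)
    rows = rowvals(Ki)
    vals = nonzeros(Ki)
    s = zero(promote_type(eltype(a), eltype(b), eltype(Ki)))
    for k in axes(Ki, 2)              # loop over columns
        for p in nzrange(Ki, k)       # stored entries of column k
            s += conj(a[rows[p]]) * b[k] * vals[p]
        end
    end
    return s
end

# Usage: Ki has nonzeros at (1,1), (2,3), (3,2)
Ki = sparse([1, 2, 3], [1, 3, 2], [2.0, -1.0, 4.0], 3, 3)
a = [1.0, 2.0, 3.0]
b = [0.5, -1.0, 2.0]
s = rank_one_dot(a, b, Ki)   # -15.0
```

The three-argument `dot(a, Ki, b)` from `LinearAlgebra` computes the same quantity with a sparse specialization, so an actual rule could likely just call that; wiring it into a Zygote or ChainRules rule is not shown here.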
