I've been thinking some more about the example I posted above. In practice, one wouldn't write `inv` in the code at all. One can also define a rule for the linear solve directly, so `u` and `I` should be a single function call. This means the following line can be optimized in the `rrule` of the linear solve by reusing the factorization, reusing the solution of the linear solve, and representing the result lazily as a rank-1 matrix.
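Here is a minimal sketch of why one factorization is enough, assuming real inputs and a scalar output $y = a^\top A^{-1} b$. The pullback needs $u = A^{-1} b$ from the forward pass and one adjoint solve $\lambda = A^{-\top} a$, which can reuse the same LU factors. The gradient with respect to `A` is then the rank-1 matrix $-\lambda u^\top$. `invquad_with_grads` is a hypothetical helper written for illustration. It is not LinearSolve.jl's actual rule, and it materializes `dA` only to keep the sketch short.

```julia
using LinearAlgebra

# Hypothetical helper (not LinearSolve.jl's rrule): value and gradients of
# y = dot(a, A \ b) for real inputs, using a single LU factorization of A.
function invquad_with_grads(a, A, b)
    F = lu(A)        # factorize once
    u = F \ b        # forward solve: u = A \ b
    λ = F' \ a       # adjoint solve reuses the same factors: λ = A' \ a
    y = dot(a, u)
    da = u           # ∂y/∂a = A \ b
    db = λ           # ∂y/∂b = A' \ a
    dA = -λ * u'     # ∂y/∂A = -λ u', a rank-1 matrix (dense here for simplicity)
    return y, da, dA, db
end
```

A lazy version would return the two vectors `-λ` and `u` instead of forming `dA`, which is the point of the rank-1 representation.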
Then one can define a rule for $dx_i = \operatorname{tr}(dK^\top K_i)$, i.e. `dot(A, B)` in Julia, where `A` is a rank-1 matrix and `B` is a `SparseMatrixCSC`. With that, we can get near-optimal performance in the above case using Zygote. I'm not sure whether Enzyme rules support lazy arrays for the adjoint. This may turn out to be easier than I initially thought, without requiring too many changes.
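A minimal sketch of such a `dot` rule, assuming a hypothetical lazy type `RankOne` that stores only the factors of $R = x y^H$. This is not the type ChainRules.jl or LinearSolve.jl actually returns. Since $\operatorname{dot}(R, S) = \sum_{ij} \overline{R_{ij}} S_{ij} = x^H S y$, the rule can forward to the three-argument `dot(x, S, y)` from the standard library. That costs $O(\operatorname{nnz}(S))$ and never forms the dense $n \times n$ matrix.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical lazy rank-1 matrix R = x * y'; only the two factor vectors are stored.
struct RankOne{T} <: AbstractMatrix{T}
    x::Vector{T}
    y::Vector{T}
end
Base.size(R::RankOne) = (length(R.x), length(R.y))
Base.getindex(R::RankOne, i::Int, j::Int) = R.x[i] * conj(R.y[j])

# dot(R, S) = sum(conj(R[i, j]) * S[i, j]) = tr(R' * S) = x' * S * y,
# evaluated in O(nnz(S)) without materializing R.
LinearAlgebra.dot(R::RankOne, S::SparseMatrixCSC) = dot(R.x, S, R.y)
```

In the setting above, `R` would play the role of the lazy cotangent $dK$ and `S` the sparse $K_i$.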
The PR above just got merged. In the next version of LinearSolve.jl, if you use Zygote.jl (or any ChainRules.jl-based AD package) to take the gradient of `dot(a, A \ b)` with respect to the matrix `A`, you will get a lazy outer product. This holds when `A` is dense:
```julia
using LinearAlgebra, LinearSolve, Zygote

function invquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(
        prob,
        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.RFLUFactorization),
    )
    return dot(a, sol.u)
end

n = 100; A = rand(n, n); b1 = rand(n); b2 = rand(n);
db1, dA, db2 = Zygote.gradient(invquad, b1, A, b2);
Base.summarysize(dA)
# 1752
Base.summarysize(A)
# 80040
```
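The following is a rough accounting of these sizes. It assumes 8-byte `Float64` entries and that the lazy outer product stores only its two length-$n$ factor vectors:

$$
\text{dense } A:\; 8n^2 = 8 \cdot 100^2 = 80{,}000 \text{ bytes}, \qquad \text{lazy } dA:\; 2 \cdot 8n = 2 \cdot 8 \cdot 100 = 1{,}600 \text{ bytes}
$$

The rest of the reported 80040 and 1752 bytes is array and wrapper overhead. The gradient's memory therefore grows like $O(n)$ instead of $O(n^2)$.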
If `A` is sparse, `Zygote.pullback` gives you the correct lazy outer product. `Zygote.gradient`, however, projects it to a sparse matrix with the same sparsity structure as the input. This is arguably incorrect (Gradient wrt to a sparse matrix is mathematically wrong · Issue #1507 · FluxML/Zygote.jl · GitHub), because the gradient is mathematically defined at the structural zeros too. The fact that they are structural zeros is an implementation detail. IMO the projection should be done by the user after the gradient call, if needed. Anyway, if you use `pullback` directly, you are safe:
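When the projected gradient is actually wanted (for example, to update only the stored entries of a sparse matrix), the user can compute it after the fact. The sketch below assumes the lazy gradient is available as factor vectors `x` and `y` with $G = x y^H$. For `sinvquad` these would be `-(A' \ a)` and `A \ b`. `project_rank1` is a hypothetical helper, not a Zygote or SparseArrays function. It fills in only the stored entries of `S`, in $O(\operatorname{nnz}(S))$ time, without forming `G`.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: restrict the rank-1 matrix G = x * y' to the sparsity
# pattern of S, evaluating G only at S's stored positions.
function project_rank1(x::AbstractVector, y::AbstractVector, S::SparseMatrixCSC)
    rows = rowvals(S)
    vals = similar(nonzeros(S), promote_type(eltype(x), eltype(y)))
    for j in axes(S, 2)
        for k in nzrange(S, j)
            vals[k] = x[rows[k]] * conj(y[j])  # G[i, j] for the k-th stored entry
        end
    end
    # Same pattern as S (copied index arrays), new values.
    return SparseMatrixCSC(size(S, 1), size(S, 2), copy(S.colptr), copy(rows), vals)
end
```

Everything outside the pattern is dropped, which is exactly the information the projection throws away.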
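```julia
using SparseArrays  # provides `sparse`

# Reuses `A`, `b1` and `b2` from the dense example above.
function sinvquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(prob)  # let LinearSolve pick its default algorithm for a sparse matrix
    return dot(a, sol.u)
end

sA = sparse(A)
# `pullback` returns (value, back); back(1.0) returns one cotangent per argument,
# and the second one is the cotangent for `sA`.
dsA = Zygote.pullback(sinvquad, b1, sA, b2)[2](1.0)[2]
Base.summarysize(dsA)
# 1752
Base.summarysize(sA)
# 160968
```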
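For the sparse case, assume `Int64` indices and that `sparse(A)` stores all $n^2 = 10^4$ entries (a dense `rand` matrix almost surely has no exact zeros). Under those assumptions, the CSC arrays alone account for:

$$
\underbrace{8 \cdot 10^4}_{\texttt{nzval}} + \underbrace{8 \cdot 10^4}_{\texttt{rowval}} + \underbrace{8 \cdot (n+1)}_{\texttt{colptr}} = 80{,}000 + 80{,}000 + 808 = 160{,}808 \text{ bytes}
$$

That is close to the reported 160968 bytes. The lazy gradient is still 1752 bytes, because its size depends only on $n$, not on the number of stored entries.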