How do you speed up the linear sparse solver in Zygote?

I’d expand on that: the C boundaries are what’s largely not covered beyond BLAS and SpecialFunctions.jl (which have specialized rules), so for “normal” Julia code that’s:

  • LAPACK
  • SuiteSparse
  • MPFR BigFloats
  • ARPACK
  • RMath (used in Distributions.jl for some of the distribution implementations)

None of that is fundamental, though, because it’s just missing rules, and the reason they’re missing is that these are the spots where the common Julia solution is to call out to a C library.

And note that because Enzyme supports mutation, and any well-written type-stable function will work with it, pure-Julia replacements for these things tend to work out of the box, though for some of them you still want rules.

As I explained in my linked post, it’s not simply a matter of a missing rule — if you don’t construct an implicit representation of the tangent vector upstream, the input to your adjoint/rrule/pullback/vJp will be a dense matrix, which defeats the point of sparsity.

It’s not that this can’t be done — I pointed out the required implicit representation in one common case in my post, and @mohamed82008 gave some example code. But do any current AD systems do this by default without manual intervention?

It requires manual intervention, but maybe ImplicitDifferentiation.jl can help here. It seems constructing the sparse matrix A(x) makes sense for solving A(x) \ b optimally. However, the optimality condition is A(x) * y = b, and if you can write that one in a differentiable way (i.e. without the sparse constructor), you’re good to go.
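
For illustration, here is a minimal sketch of that last point: the residual of the optimality condition A(x) * y = b evaluated directly from the COO triplets (rows, cols, vals) of A(x), with no call to the sparse constructor. The function name and argument layout are just illustrative, not an existing API.

function coo_residual(rows, cols, vals, y, b)
    r = -copy(b)                      # start from -b
    for (i, j, v) in zip(rows, cols, vals)
        r[i] += v * y[j]              # accumulate (A(x) * y)[i]; duplicate (i, j) pairs sum, as in sparse()
    end
    return r                          # zero exactly when A(x) * y = b
end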

Note: I just checked, and this is all a non-issue with Enzyme, as it differentiates these fine:

using Enzyme, SparseArrays
function f1(A)
    sum(SparseMatrixCSC(A))
end

A = rand(2,2)
dA = zeros(2,2)
Enzyme.autodiff(Reverse, f1, Duplicated(A, dA))
@show dA
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0

I = [1, 4, 3, 5]; J = [4, 7, 18, 9]; V = [1.0, 2, -5, 3];
S = sparse(I,J,V)
function f2(I,J,V)
    sum(sparse(I,J,V))
end

dV = zeros(4)
Enzyme.autodiff(Reverse, f2, I, J, Duplicated(V, dV))
@show dV

4-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0

How do you know it’s not constructing a dense matrix as an intermediate step?

No. The issue has nothing to do with \ (which doesn’t require implicit differentiation anyway). As I explained in my linked post, the same issue potentially arises with something as simple as f(x) = u^T A(x) v.

For reference, Zygote.jl: How to get the gradient of sparse matrix - #12 by mohamed82008

Actually, it looks good — I tried it with a 10^6 \times 10^6 sparse matrix, which would have run out of memory if it tried to allocate something dense, and it seems to work fine for my simple example:

using Enzyme, SparseArrays
const u = rand(10^6)
const v = rand(10^6)
g(I,J,x) = u' * sparse(I,J,x) * v
I = rand(1:10^6, 10^7)
J = rand(1:10^6, 10^7)
x = rand(10^7); dx = zero(x)
Enzyme.autodiff(Reverse, g, I, J, Duplicated(x, dx))

works — comparing to the analytical solution dx ≈ u[I] .* v[J] returns true.

So Enzyme seems to be doing the right thing by default these days, which is great news — how does Enzyme represent the tangent vector for input into the sparse adjoint?


I think a related issue is also what people mean when they want a gradient wrt a sparse matrix. Do I mean the gradient wrt the whole matrix, including its structural zeros, which is the mathematically natural way to think about this? Or do I mean the gradient wrt the matrix’s structural non-zeros? Chris’s example shows the latter, which is perhaps easier to optimise in some cases if we don’t think of the sparse matrix data structure as a “mathematical matrix” at all and just try to differentiate the sum wrt the structural non-zeros. So the matrix here is just a small vector in disguise.
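
To make the distinction concrete, here is a small sketch (the 2×2 matrix is just an illustrative example):

using SparseArrays
A = sparse([1, 2], [2, 1], [3.0, 4.0])   # 2×2 with two stored entries
# "Structural" view: sum(A) only reads the stored values, so the gradient
# is one value per stored entry, i.e. a length-2 vector in disguise.
d_structural = ones(nnz(A))              # [1.0, 1.0]
# "Mathematical" view: the gradient of sum wrt the whole matrix, structural
# zeros included, is dense.
d_dense = ones(size(A))                  # 2×2 matrix of ones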

If you want the gradient wrt the whole matrix, this is where it gets tricky, because for different functions there are various lazy representations in which it is more efficient to represent the gradient. In the inner product example, a lazy outer product makes sense as the “lazy representation” of the gradient wrt the (sparse) matrix. The lazy representation, besides saving memory for simple functions like the inner product, can also allow for more efficient back-propagation in more complex composite functions. For example, suppose I have a function f(g(x)) where g returns a sparse matrix and f returns a scalar. If the pullback of f involves the inverse of g(x), but that inverse is then multiplied by a vector in the pullback of g, then constructing the inverse could have been avoided. This is an algorithmic difference in how the gradient of f(g(x)) is computed.
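
As a sketch of what such a lazy representation could look like for the inner product example (the type and function names here are made up, not an existing package):

struct LazyOuter{T} <: AbstractMatrix{T}
    u::Vector{T}                               # gradient of u' * A * v wrt A is u * v'
    v::Vector{T}
end
Base.size(O::LazyOuter) = (length(O.u), length(O.v))
Base.getindex(O::LazyOuter, i::Int, j::Int) = O.u[i] * O.v[j]   # entries produced on demand

# Projecting onto a known sparsity pattern costs O(nnz), never O(N^2):
project(O::LazyOuter, rows, cols) = [O[i, j] for (i, j) in zip(rows, cols)]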

More theoretically, for source-to-source AD you can think of this as an advanced compiler optimisation pass that is linalg-aware and can use linalg identities to optimise performance and eliminate intermediates, avoiding loss of structure as a by-product. Can an AD tool do this on its own, without help from users, for complex cases involving A \ b where A is sparse? In theory, maybe, but it would need a linalg-aware optimization compiler pass to figure out useless or expensive intermediates (e.g. inverses of sparse matrices) and optimise them away. This might help with structured intermediates but perhaps not with the input itself being structured. Either way, I don’t think Enzyme is there (yet), but I don’t know.

If we can’t use Enzyme to efficiently differentiate through functions involving sparse linear algebra for whatever reason, then we have to rely on rules. Rules that are not lazy can easily ruin structure, sparsity or otherwise, as demonstrated in the post Steven linked to. So I think writing rules carefully so that they allow for lazy representations could be the best hope for doing reverse-mode AD on sparse linalg efficiently.

These are my 2 cents anyways. I don’t do sparse linalg autodiff for my job so this is all just a fun hobby for me.


I know, but the point is to leverage A being a sparse matrix when we do A(x) \ b, because we have efficient routines for that. So there is no way to circumvent the sparse constructor in the solver itself. However, I was thinking the optimality condition A(x) * y = b could be written efficiently without generating a sparse matrix object, which opens the door to implicit differentiation. Indeed, the interesting case is when the solver is not differentiable but the conditions are.

Related discussion:

Perhaps a good first step could be to go through ChainRulesCore and ChainRules and replace all inv(A) and matrix factorisations with lazy versions, optionally instantiating the inverse/factorization only before returning from gradient but not in the pullback. This simple change might help with some real life cases already. Any outer product should also be lazy by default. Lazy rules may also be done in Enzyme (if possible) to get the best of both worlds.
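
For illustration, a hypothetical sketch of what such a lazy inverse could look like; none of these names exist in ChainRules today:

using LinearAlgebra

struct LazyInv{F}
    fact::F                                   # a precomputed factorization, e.g. lu(K) or cholesky(K)
end
Base.:*(L::LazyInv, x::AbstractVecOrMat) = L.fact \ x   # applies K^{-1} via a solve, never forming inv(K)

# e.g. Linv = LazyInv(lu(K)); y = Linv * b   solves K * y = b without materializing the inverse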

The realistic case of interest is that you have a scalar function f(A(x)) of a sparse matrix A(x) that is constructed from some parameters x, and you want the gradient of f with respect to x. In order to pass the chain rule through the A(x) constructor, you only need to differentiate with respect to the structural nonzeros.

In principle, this is no problem. The implementation is the tricky part, because the representation of the tangent vector passed into the sparse pullback has to either know about the sparsity pattern or it has to exploit some other structure (e.g. the rank-1 structure in my example). It seems that Enzyme has solved this, at least in some cases, which I’d love to hear more about.

If we can’t use Enzyme to efficiently differentiate through functions involving sparse linear algebra for whatever reason, then we have to rely on rules.

It’s always good to know how to write your own vJp/pullback/rrule chain-rule steps, because for any sufficiently complicated problem you are likely to run into a case that AD either doesn’t handle or handles poorly. Especially if you are doing research that isn’t a cookie-cutter rearrangement of the same old ML building blocks.

It differentiates the steps, and the steps only require shadow data of like form. So if the algorithm never constructs a matrix it won’t construct a matrix. This “structural derivative” does have some interesting properties to be aware of, see Supporting covariant derivatives · Issue #1334 · EnzymeAD/Enzyme.jl · GitHub, but it at least has the guarantee that it’s not going to require more than linear memory beyond the primal (except for any special rule of course).

At a high level, I think there are effectively 4 different avenues in which progress is being made and needs to be continued. These are:

  1. Improving what the base AD system is able to support in terms of fundamental programming language features. This largely stopped with Zygote about 4 years ago, but Enzyme is pushing it forward with its mutation support, its now-robust GC support, and rudimentary type-instability support that continues to improve.
  2. Improving the rules to “core” standard library functions. This is the functions like sparse, etc. but also SpecialFunctions and the like, to be as well-covered and fast as possible.
  3. Increasing the coverage of the standard library. This is putting rules on more fringe parts of the standard library that call core pieces like Tridiagonal, which in theory should be covered by the core but are done just to avoid mutation or some performance issue.
  4. High level library overloads. This is LinearSolve.jl, NonlinearSolve.jl, DifferentialEquations.jl, etc.

But the interesting thing is that doing this in the most efficient way decreases in difficulty as you go down the ladder because you have more problem information to specialize on.

My point is that you could work very hard on (2) and (3) with lazy arrays and all sorts of tricks, which is a lot of what goes on in ChainRules/Zygote land, but ultimately it’s hard to get that truly optimal. It’s somewhat easier to just make sure the actual use cases are handled at level (4), and then make sure the setup code around them (i.e. the things using mutation to build a sparse matrix) is handled well enough, which is (1). That’s the Enzyme+SciML approach we’ve been taking over the last few years.

This is why I have more and more been going down the direction of making level (4) as covered as possible, automating things like implicit rules that you couldn’t expect AD to do (also, numerical stability, which is a whole story, since AD is not necessarily numerically stable on things like ODE solvers), while Enzyme focuses on (1) and we slowly get more coverage but also special handling of high-level mathematical bits. If you call a solver, all of the tricks are done, and the rest is handled by indexing and mutation. There are some edge cases there like BLAS which are being knocked out one by one, but the standard library of Julia doesn’t have all that much C code to handle.

Doing more FillArrays etc. with Zygote/ChainRules is trying to solve it by (2), (3), and (4) simultaneously, and ultimately it’s never bad to keep including better rules at this level, though I personally don’t think it’s the ultimate solution as there’s way too much to handle. At some point, sparse should just work because you support mutation, not because you’ve rewritten the whole standard library!

What I mean to say is, I personally think a lot of this discussion is going down the wrong road. The problem with the OP is most likely that they shouldn’t be calling \ directly in the first place. If it’s for a Newton method, BFGS, PDE solve, etc., then there’s likely a level (4) rule they would benefit from. LinearSolve.jl, NonlinearSolve.jl, DifferentialEquations.jl, … one of those pieces + Enzyme for the routine of building the sparse matrix and you’re done. That would normally get you something more optimal, because then you get the line search methods, implicit differentiation, and all of the other mathematical tricks for free, but it would also just be natural Julia code. I showed that if you just put the sparse call in an Enzyme call, and put LinearSolve.jl in an Enzyme call (LinearSolve.jl/test/enzyme.jl at main · SciML/LinearSolve.jl · GitHub), then you’re fine. So OP has a solution: just use LinearSolve.jl!
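
For reference, a hedged sketch of what that looks like for a standalone sparse solve (the random matrix here is just a stand-in for a real assembled system):

using LinearSolve, SparseArrays, LinearAlgebra

A = sprand(100, 100, 0.05) + 10I      # illustrative well-conditioned sparse matrix
b = rand(100)
prob = LinearProblem(A, b)
sol = solve(prob)                     # LinearSolve.jl picks a sparse-aware factorization by default
sol.u                                 # the solution vector, with adjoint rules handled at this level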

There are still edges of course, we’re not done with all of it (in particular I need to find time for Fix Enzyme integration following minimal example · Issue #479 · SciML/LinearSolve.jl · GitHub), but those issues are all solvable and there aren’t really hard edges. The one hard edge is that your code should probably be type-stable for Enzyme to work well right now: it can handle type-unstable codes but there are some cases that aren’t handled so the easiest advice is just make it type stable. But other than that and Supporting covariant derivatives · Issue #1334 · EnzymeAD/Enzyme.jl · GitHub, things are looking pretty good for this domain.


I do, we use it daily and have tests on it. Sparse direct and Krylov methods. That’s one of the big reasons for LinearSolve.jl. All of them are solving linear systems. You don’t need to put a rule on SuiteSparse UMFPACK, then KLU, then Krylov.jl, then Pardiso. You just make them all one interface, LinearSolve.jl, define the adjoint for a linear solver (which we did for ChainRules (Zygote/Diffractor) and Enzyme), and now all of those are supported.

Then we did that for nonlinear solvers, ODEs, etc. That’s why SciML is built like that: it’s the right level of abstraction, tying differentiation to abstract mathematical problems, not to specific solvers.


Some of this was touched on in the issues/related docs, but I’ll give some more sparse details here… which, if anyone wants to help turn into docs, would be fantastic.

We actually talked about this sort of thing a few weeks ago in the Enzyme weekly open meetings (also shameless plug to attend EnzymeCon next week at MIT: Enzyme Conference 2024)

Essentially, the way Enzyme represents all data structures (arrays, linked lists, trees, and, as relevant to this discussion, sparse data structures) is to have the shadow (aka derivative) memory use the same memory layout as the primal. Suppose you have an input data structure x. The derivative of x at byte offset 12 will be stored in the shadow dx at byte offset 12, etc.

This has the nice property that the storage for the derivative, including all intermediate computations, is the same as that of the primal (ignoring caching requirements for reverse mode).

It also means that any arbitrary data structure can be differentiated with respect to, and we don’t have any special handling required to register every data structure one could create.

This representation does have some caveats (e.g. see Why does `Duplicated(x, dx)` assume `x` and `dx` have the same type? · Issue #1329 · EnzymeAD/Enzyme.jl · GitHub and Supporting covariant derivatives · Issue #1334 · EnzymeAD/Enzyme.jl · GitHub and Home · Enzyme.jl for relevant deep dives), but I’ll discuss the relevant parts to sparsity below.

Sparse data structures are often (and, I believe, in Julia) represented with, say, a Vector{Float64} that holds the actual elements, and a Vector{Int} that specifies, for each entry of the backing array, the index in the overall vector it corresponds to.

We have no explicit special cases for Sparse Data structures, so the layout semantics mentioned above is indeed what Enzyme uses.

Thus the derivative of a sparse array is a second backing array of the same size, along with another Vector{Int} (with the same offsets).
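
For concreteness, this is what that layout looks like for Julia’s CSC matrices (reusing the example matrix from earlier in the thread; field access is shown only for illustration):

using SparseArrays
S = sparse([1, 4, 3, 5], [4, 7, 18, 9], [1.0, 2.0, -5.0, 3.0])
nonzeros(S)   # Vector{Float64}: the backing array of stored values
rowvals(S)    # Vector{Int}: the row index of each stored value
S.colptr      # Vector{Int}: where each column's entries start in the backing array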

As a concrete example, suppose we have the following:
x = { 3 : 2.7, 10 : 3.14 }, which is to say a sparse data structure with two elements, one at index 3, another at index 10. This could be represented with the backing array being [2.7, 3.14] and the index array being [3, 10]

A correctly zero-initialized shadow data structure would have a backing array of size 2 filled with zeros, and an index array again being [3, 10].

In this form the second element of the derivative backing array is used to store/represent the derivative of the second element of the original backing array, in other words the derivative at index 10.

A caveat here (discussed in our FAQs here: Home · Enzyme.jl) is that this correctly zeroed initializer is not the default produced by sparse(0). Instead we provide, and generally recommend using, Enzyme.make_zero, which recursively goes through your data structure to generate a shadow of the correct structure (and in this case would make a new backing array). Again, the make_zero function is not special-cased to sparsity at all; this just comes out as a result.
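
A small sketch of that shadow construction for the example above (a sparse vector with stored entries at indices 3 and 10), assuming sum over a SparseVector differentiates the same way as the SparseMatrixCSC examples earlier in the thread:

using Enzyme, SparseArrays

x = sparsevec([3, 10], [2.7, 3.14])
dx = Enzyme.make_zero(x)                       # same index array, zeroed backing array
(dx.nzind == x.nzind, all(iszero, dx.nzval))   # (true, true)

Enzyme.autodiff(Reverse, sum, Active, Duplicated(x, dx))
dx.nzval                                       # [1.0, 1.0]: derivatives of the two stored entries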

Internally, when differentiating a function this is the type of data structure that Enzyme builds and uses to represent variables. It is at the julia level that there’s a bit of a sharp edge.

@stevengj, in your example f(A(x)) where A returns a sparse data structure, the differentiable version of A that Enzyme generates would create both the backing/index arrays for the original result of A, as well as the equal-sized backing/index arrays for the derivative.

For any program which generates sparse data structures internally, this will always give you the answer you expect – and with the corresponding memory requirements outlined above.

The added caveat, however, is if you want to differentiate a top-level function that takes in a sparse array. For example, consider the sum function over all elements. While in one semantic meaning it is meant to represent summing up all elements of the virtual sparse array, in a more literal sense the sum will only add elements 3 and 10 of the input sparse array (the only two nonzero elements), or equivalently the sum of the whole backing array. Correspondingly, Enzyme will update the sparse shadow data structure to mark both elements 3 and 10 as having a derivative of 1 (or, more literally, set all the elements of the backing array to a derivative of 1). These are the only variables that Enzyme needs to update, since quite literally they are the only variables read (and thus the only ones with a non-zero derivative). This is why this memory-safe representation composes within Enzyme.

If the name we gave to this data structure wasn’t “SparseArray” but instead “MyStruct”, this is precisely the answer we would have desired. However, since the sparse array printer would otherwise print zeros for elements outside of the sparse backing array, this isn’t what one would expect. Making a nicer user-facing conversion from Enzyme’s form of differential data structures to the more natural “Julian” form, for cases where there is a semantic mismatch between what Julia intends a data structure to mean by name and what is actually implemented, is being tracked here (Supporting covariant derivatives · Issue #1334 · EnzymeAD/Enzyme.jl · GitHub).

The benefit of this representation, however, is that all of our rules compose correctly (e.g. how you got the correct answer for f(A(x))), without the need to special-case any sparse code, and with the same memory/performance expectations as the original code.

Another quick caveat: like Chris said, if Julia calls an external library like SuiteSparse that we don’t have a derivative for, Enzyme will throw a “no derivative found” error for that function. This is easily resolvable by writing a custom rule or adding internal Enzyme support for whatever function arises. This is admittedly limited for SuiteSparse at the moment, as we’ve been focusing on other areas of expanding the scope of Enzyme.jl’s Julia language support (more so on type-unstable code), but all contributions are very welcome.


Having more cookie cutters is useful but it doesn’t solve every problem out there. A language-wide AD can in theory do better on its own. Consider this case which @stevengj and others are probably very familiar with from topology optimisation:

\begin{aligned}
K(x) &= K_0 + \sum_{i=1}^N x_i \times K_i \\
I(K) &= K^{-1} \\
u(I) &= I \cdot f \\
c(u) &= u^T \cdot f
\end{aligned}

where

  • x, f and u are vectors of sizes O(N),
  • K is a large invertible symmetric positive definite matrix of size O(N) \times O(N),
  • K_i for each i is a symmetric hyper-sparse matrix with O(1) non-zeros, and
  • c is a scalar.

The back-propagation written manually is:

\begin{aligned}
dc &= 1 \\
du &= dc \cdot f \\
dI &= du \cdot f^T \\
dK &= -K^{-1} \cdot dI \cdot K^{-1} \\
dx_i &= tr(dK^T \cdot K_i)
\end{aligned}

Computing the gradient of c wrt x the naive way would be inefficient for 2 main reasons:

  1. dI is an outer product of 2 vectors so it shouldn’t require more than O(N) memory
  2. Computing dK requires the inverse of K and 2 matrix-matrix multiplications

Now let’s assume that dI and dK are combined, because dI was only computed lazily; we get:

\begin{aligned}
dc &= 1 \\
du &= dc \cdot f \\
dK &= -K^{-1} \cdot (du \cdot f^T) \cdot K^{-1} \\
dx_i &= tr(dK^T \cdot K_i)
\end{aligned}

The bottleneck is clearly the line:

dK = -K^{-1} \cdot (du \cdot f^T) \cdot K^{-1}

If a lazy linear algebra library is used, this can be re-written as

dK = -(K^{-1} \cdot du) \cdot (K^{-1} \cdot f)^T

which requires a single factorisation of K, 2 linear system solves and an outer product. But we can do better. K^{-1} \cdot f was computed in the forward pass as u so it shouldn’t be re-computed. Also K was factorised in the forward-pass so it shouldn’t be re-factorised.

dK = -(K^{-1} \cdot du) \cdot u^T

Next we should do the outer product lazily to avoid constructing a large dense matrix for dK.

Because dK requires O(N) memory to store and K_i is hyper-sparse with O(1) non-zeros, computing dx_i is O(1) for each i and O(N) for all of x. This is actually the optimal way to propagate derivatives in this case.
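
A minimal sketch of that back-propagation, assuming K was factorized in the forward pass (Kfact, with u = Kfact \ f already computed) and each hyper-sparse K_i is given in COO form (rows, cols, vals). All names here are illustrative, not TopOpt.jl’s actual implementation:

using LinearAlgebra

function grad_c(Kfact, u, f, Ks)
    du = f                          # dc = 1, so du = dc * f = f
    w = Kfact \ du                  # K^{-1} * du; K is symmetric, so no transpose solve is needed
    # dK = -w * u' is kept lazy; contract it against each O(1)-sized K_i:
    [-sum(w[r] * u[c] * v for (r, c, v) in zip(rows, cols, vals))
     for (rows, cols, vals) in Ks]
end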

So to automate the above, we need the following:

  • Lazy inverse of a matrix
  • Lazy matrix chain multiplication with re-ordering of multiplication operations
  • Linear solve and factorisation memoization (or common sub-expression elimination) to realise that u and K's factorisation can be re-used from the forward-pass
  • Lazy outer product of 2 vectors

The tricky part is figuring out when to be lazy, when to be eager, and when to memoise. In theory, getting near-optimal performance from any reverse-mode AD system (Zygote/ReverseDiff/Enzyme) should be possible in this case with all the optimisations above in place, without having to define an efficient rule for c(u(I(K(x)))), which is what I do in TopOpt.jl, adding a new “cookie cutter” to the collection. A new rule is needed in this case because the AD system can’t do the above optimisations on its own, because rules were not defined in a lazy way that enables inter-rule optimisations. Notice that this issue is not just a “tangent representation issue”; it’s a bigger issue of how to fuse rules together and eliminate intermediates. The tangent representation issue is relevant, and may have been a reason why some AD systems are not able to use lazy rules, but it’s not the only issue here.

The above optimisations are analogous to asking Julia to be smart in the following non-AD case:

A = inv(K) * (f * f')
u = inv(K) * f

by:

  • Re-arranging the parentheses,
  • Not inverting K and doing its factorisation only once, and
  • Storing A as a lazy rank-1 matrix.

I don’t think that’s the fault of the AD system, but rather that Julia’s optimizer cannot perform those optimizations.

Enzyme’s raison d’etre isn’t to be the one AD to rule them all, but rather to study the impact of optimizations on automatic differentiation.

The Enzyme hypothesis (which has been subsequently validated in numerous papers since) is that performing optimization prior to differentiation yields much faster code than differentiation coming first (and has since also been extended to interpolating differentiation with optimization).

I think the issue here is that Julia doesn’t support such linear algebra optimizations at the moment, and it would very much benefit from doing so — including for regular code not being differentiated.

For example, in the Enzyme-JAX project (which can use Enzyme for AD on JAX source code), we specifically have linear-algebra-level optimizations (including ones we have written ourselves). These optimizations speed up regular JAX code on their own, but again their impact is amplified when combined with differentiation.


I agree. The issue is more general than Enzyme vs Zygote vs ReverseDiff. It’s just that for this kind of function, making sure the forward pass is optimised does not guarantee that the reverse pass will be optimal. We can infuse laziness into the forward pass to optimise it, but optimising the reverse pass requires either linalg-aware compiler optimisations with source-to-source AD or strategically using “lazy rules”.

The primary missing optimization for Julia right now is a lack of CSE and loop hoisting. To fix these, we need to tell LLVM about the effect-analysis results that we already track. Gabriel had a PR to do so, but it’s fairly stalled at the moment.