How do you speed up the linear sparse solver in Zygote?

I have confirmed that the standard Julia solver (backslash) works with Zygote’s automatic differentiation.
Below is a sample program. (This program calculates the gradient.)

However, while the forward evaluation of the objective function is fast, computing the derivative with automatic differentiation is very slow. (Memory usage is not that large.)

What do people do to speed up Zygote’s automatic differentiation?
Is it common to write your own rrule so that the derivative of a large function is computed directly, as a single unit?

using Zygote
using ChainRules
using LinearAlgebra
using SparseArrays

function test_func(x)

    n = length(x)

    # build the matrix (Zygote.Buffer allows element-wise assignment under Zygote)
    A_buf = Zygote.Buffer(spzeros(Float64, n, n))
    k_tmp = rand(Float64, n, n)
    for i in 1:n
        for j in 1:n
            A_buf[i, j] += k_tmp[i, j] + x[j]^2.0
        end
    end

    # Zygote.Buffer -> SparseMatrixCSC
    A = copy(A_buf)

    # right-hand side
    b = fill(1.0, n)

    # solve the linear system
    c = A \ b

    # return the value of the objective function
    return norm(c)
end

s = rand(Float64, 200)
@time value = test_func(s)

# gradient via Zygote's adjoint syntax
df_auto(x) = test_func'(x)
@time display(df_auto(s))

for i in 1 : n
  for j in 1 : n
    A_buf[i, j] += k_tmp[i, j] + x[j]^2.0

Doing the above for a sparse matrix is terribly inefficient for 3 reasons.

  • First, your matrix is not actually sparse if you are assigning non-zero values to all its elements. Using a sparse matrix data structure typically only makes sense when the ratio of non-zeros in your matrix is (much) less than 10%. Benchmark it.
  • Second, assigning a non-zero value to a structural zero in a sparse matrix is an O(N) operation where N is the number of existing non-zeros in the matrix. The number of existing non-zeros N before the assignment in your loop iterations goes from 0 to n^2-1. This gives you a complexity of O(n^4).
  • Third, even if the above pitfalls are avoided by using a regular Matrix instead of a sparse one, Zygote is actually very bad at differentiating scalar code like the above where the number of function calls scales up with the input size. So if you want to use Zygote, you need to “vectorise” your code (like you would do in Matlab or Python) to reduce the number of function calls, preferably making it O(1), i.e. independent of the input size. Otherwise, look into Enzyme, which in theory can handle code like this better. If you have to write scalar code with loops and Enzyme doesn’t work for you, you can wrap the loop-heavy part in a function and define an rrule for it. Zygote will then use your rule instead of trying to differentiate through the loops.
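
As a rough illustration of the third point, here is a minimal sketch of a vectorised, dense version of the loop. The name test_func_vec and its signature are made up for this example, and k and b are passed in as arguments (an assumption, so that the function is deterministic) rather than generated inside the function as in the original post.

```julia
using Zygote
using LinearAlgebra

# Vectorised variant of the double loop: A[i, j] = k[i, j] + x[j]^2.
# Hypothetical signature: `k` and `b` are arguments instead of being created inside.
function test_func_vec(x, k, b)
    A = k .+ (x .^ 2)'   # broadcast: adds x[j]^2 to every entry of column j
    return norm(A \ b)   # dense solve, then the norm of the solution
end

n = 200
k = rand(n, n)
b = ones(n)
s = rand(n)

value = test_func_vec(s, k, b)
grad = Zygote.gradient(x -> test_func_vec(x, k, b), s)[1]
```

The number of operations Zygote has to trace here no longer grows with n; only the array sizes do.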

Even if you use sparse matrix data structures etcetera, AD systems often need “help” (i.e. custom vJp/pullback rules) when differentiating functions that construct sparse matrices as an intermediate step. See e.g. this discussion: Zygote.jl: How to get the gradient of sparse matrix - #6 by stevengj

In practice, my group has always ended up writing custom rrules for such cases. Once you stray outside the confines of conventional ML functions (i.e. the usual neural-net building blocks) and get more into things like scientific computing, I find that for any sufficiently complicated calculation you eventually need to supplement AD systems with custom chain-rule steps.

Fortunately, once you understand the rules of “matrix calculus” (as in our MIT short course), it’s pretty straightforward to manually differentiate functions like your c(x) = \Vert A(x)^{-1} b \Vert where A(x) constructs a sparse matrix from x. If you supply that, AD systems can then handle propagating the chain rule through any calculations that c is composed with.
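
For reference, the manual derivative mentioned above can be sketched as follows. The symbols u (the solution vector, called c in the code of the original post) and \lambda (an adjoint vector) are notation introduced here, not taken from the thread, and the derivation assumes A(x) is invertible and u \neq 0.

```latex
% Setup: u = A(x)^{-1} b, \qquad c(x) = \Vert u \Vert
% Differentiate A u = b with b fixed:
dA\, u + A\, du = 0 \quad\Longrightarrow\quad du = -A^{-1}\, dA\, u
% Differential of the norm:
dc = \frac{u^T du}{\Vert u \Vert} = -\frac{u^T}{\Vert u \Vert} A^{-1}\, dA\, u = -\lambda^T\, dA\, u,
\qquad \text{where } A^T \lambda = \frac{u}{\Vert u \Vert}
% Gradient components (one extra "adjoint" solve in total, independent of the number of parameters):
\frac{\partial c}{\partial x_k} = -\lambda^T \frac{\partial A}{\partial x_k} u
% For the matrix in the original post, A_{ij} = k_{ij} + x_j^2:
\frac{\partial c}{\partial x_k} = -2\, x_k\, u_k \sum_i \lambda_i
```

The whole gradient costs one forward solve and one solve with A^T, and both can reuse the same factorization of A.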


One of my students once commented that writing AD-able code is like writing code in another language that only superficially resembles Julia.


Sadly this is also my experience. From what I heard, Diffractor (Enzyme) was (is) supposed to change that.

Replying to this just to do some expectation setting as an Enzyme dev:

While Enzyme is obviously a great tool (see my previous bias as an Enzyme dev), it does have limitations (which we try to be upfront about and document, and if something is missing, docs PRs are very welcome). We are intentionally starting at a smaller core of Julia (e.g. that compatible with GPUCompiler) and growing support for general Julia code as time goes on.

I think the biggest pain point for folks at the moment is that we don’t have complete support for type-unstable Julia code. We’ve added much more support here in the past few months, but this is still something people are likely to run into.

A previous pain point that we’ve since mostly resolved is BLAS: we now have fast versions of most common BLAS functions (gemv, gemm, dot, etc). However, we haven’t done this for LAPACK yet (which Julia will call from det, etc).

A lot of this can be alleviated by folks who open issues on these, but especially those who can contribute (perhaps in a custom rule for what they are missing).

In 2024, I only have 2 weeks of my time funded for doing things related to Enzyme.jl; my primary funding is actually to support Enzyme in C++ and Enzyme in JAX. Since these are based around a common core, this work still helps everyone, but it does mean that adding Julia-specific features (like more support for type-unstable Julia calls) depends on when Valentin and I have spare time. We do our best, but this is also why you’ll see Julia features come more on late nights/weekends than during the day.


I understand the idea behind this gripe, but I find that AD-able code in Julia tends to follow what I would consider to be good practices in Julia anyway: avoiding mutation (and, by extension, a functional code style that prefers map over for loops to update out of place) and aiming to be as type-stable as you can.

The only other AD tools I’ve worked with have been in Python, but does JAX really feel better than Julia with respect to “coding in a language that only superficially resembles the original language”? Perhaps static languages have a leg up here, but I am not familiar with using reverse-mode AD in any static languages.

Yes and no. In Python, being forced to avoid writing your own loops and desperately hoping that there is a library routine that can be coaxed into doing what you want are “normal” situations for high-performance code.

Also, while JAX is a closed universe of AD-able functions, it is a very well-funded and extensive closed universe.


I still think the solution is to write your own custom rrules.
Would replacing them with custom rrules also reduce memory?
I am also struggling with the amount of memory used in automatic differentiation.

In fact, I am trying to apply it to automatic differentiation for finite element analysis (FEA).
So in actual use, the matrix will be properly sparse.
However, what I noticed is that the linear solver part takes too much time, even at small sizes.
That is why I asked this question.

There has been some discussion about the impact of applying automatic differentiation on Julia’s coding, so here are my thoughts.
I have implemented the same project with each of the following three tools. The project involves numerical simulations using the finite element method (FEM).

  1. Enzyme.jl
    I applied this package first. The fact that I can freely mutate arrays without any special operations is appealing. I was able to realize what I wanted to do.
    However, one of Julia’s attractions, the linear algebra library, is not supported, and I created all the necessary content myself. In addition, the depth of nested structures to which it could be applied also seemed to be limited. Because it is difficult to create a high-performance linear algebra package from scratch, I am still waiting for an update to Enzyme.jl.

  2. jax (Python)
    I tried this second.
    However, directly programming the theory of numerical analysis requires for loops.
    Python needs to avoid for loops for the sake of computational speed, so vectorization is necessary. I am not happy about that.
    In the end, I gave up using it in terms of computation speed and memory.
    I decided to use Julia, which can directly express the theory of numerical analysis.

  3. Zygote.jl
    I am currently at this stage.
    Originally, I didn’t use it because I found it difficult that it didn’t allow any mutation of arrays. However, I managed to solve the problem by using the type Zygote.Buffer, so I used it. I think this gives the most Julia-like program because of the extensive linear algebra library and the few restrictions on structures.

If Enzyme.jl adds support for linear algebra, including sparse matrices, to the same extent as Zygote.jl, it is possible that I will return to Enzyme.jl.

Does Sparspak.jl + ForwardDiff work for you?

Yes, in my experience often a custom rrule can be both faster and more memory-efficient than reverse-mode AD, assuming you know what you are doing in optimizing Julia code. Basically, you understand the structure of your program in a way that AD doesn’t, and you can exploit that understanding.
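
To make this concrete, here is a minimal sketch of such a rule, not a definitive implementation. It assumes a deliberately simple hypothetical model, A(x) = K + Diagonal(x.^2) with K a fixed sparse matrix, as a stand-in for a real FEM assembly. The names assemble and solve_norm are invented for this example, and K and b are treated as constants.

```julia
using ChainRulesCore
using LinearAlgebra
using SparseArrays

# Hypothetical stand-in for an FEM assembly: A(x) = K + Diagonal(x.^2).
assemble(x, K) = K + spdiagm(0 => x .^ 2)

# Scalar objective c(x) = ‖A(x)⁻¹ b‖.
solve_norm(x, K, b) = norm(assemble(x, K) \ b)

function ChainRulesCore.rrule(::typeof(solve_norm), x, K, b)
    F = lu(assemble(x, K))      # factorize once; reused by the adjoint solve
    u = F \ b
    c = norm(u)
    function solve_norm_pullback(c̄)
        λ = F' \ (u ./ c)       # adjoint solve: Aᵀ λ = u / ‖u‖
        # For this diagonal model ∂A/∂x_k = 2 x_k e_k e_kᵀ, so ∂c/∂x_k = -2 x_k λ_k u_k.
        x̄ = unthunk(c̄) .* (-2 .* x .* λ .* u)
        # Assumption of this sketch: K and b are constants, so no tangents are returned for them.
        return NoTangent(), x̄, NoTangent(), NoTangent()
    end
    return c, solve_norm_pullback
end

# Usage: call the rule directly and seed the pullback with 1.0 to get the gradient.
n = 50
K = spdiagm(-1 => fill(-1.0, n - 1), 0 => fill(4.0, n), 1 => fill(-1.0, n - 1))
b = ones(n)
x = rand(n)
c, back = rrule(solve_norm, x, K, b)
g = back(1.0)[2]
```

The rule stores only the factorization and the solution vector, and the backward pass is a single extra sparse solve. AD never has to trace through the matrix construction or the solver, which is where both the time and the memory savings come from.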

Did you try LinearSolve.jl? It adds rules for Zygote and Enzyme and wraps pretty much every linear solver that people use in practice. That should handle this just fine?


To compute \nabla_x \Vert A(x)^{-1} b \Vert with current AD tools where A(x) is sparse, you still have the problem of trying to differentiate through the sparse-matrix constructor, no?

Thank you. I understand. I will try to define my own rrule for the areas where I want to improve efficiency.

I tried small tests with Sparspak.jl + ForwardDiff.jl + ExtendableSparse.jl last week, but it didn’t work.
I may do some additional testing and post the results as a new topic.
Would it work well without ExtendableSparse?

It should. Basically it should work just as well with ExtendableSparse. Could you open an issue there (possibly with an MWE)?

I checked my tests and yes, sparse is missing an adjoint. We have tests differentiating with respect to b in Ax=b, but not differentiating A with respect to what it’s constructed from using sparse without going through a dense A (i.e. SparseMatrixCSC(A) has the right derivative overloads). Zygote handles the sparse matrix construction, though Enzyme does not, as it hits a SuiteSparse barrier and needs a rule.

Enzyme largely does this, and we use it all the time to great effect! The main issues with Enzyme these days are LAPACK functions and SuiteSparse, and that of course is because it’s hitting non-Julia code and just needs rules to handle that boundary. A lot of this is handled in SciML tooling, though there are a few spots (NonlinearSolve.jl for example) which need more Enzyme rules where there currently exist ChainRules. So there are a few rough edges right now, but it ultimately feels solvable, unlike the Zygote/ChainRules issue, which IMO is wrong at a fundamental level.

Most of the cases with garbage collection are handled. There’s still some edge cases with type uninferred code though.

You wouldn’t want to use forward-mode AD in cases where you have a scalar output and lots of input parameters.
