How do you speed up the linear sparse solver in Zygote?

Even if you use sparse-matrix data structures, AD systems often need “help” (i.e. custom vJp/pullback rules) when differentiating functions that construct sparse matrices as an intermediate step. See e.g. this discussion: Zygote.jl: How to get the gradient of sparse matrix - #6 by stevengj

In practice, my group has always ended up writing custom rrules for such cases. Once you stray outside the confines of conventional ML functions (i.e. the usual neural-net building blocks) and get more into things like scientific computing, I find that for any sufficiently complicated calculation you eventually need to supplement AD systems with custom chain-rule steps.

Fortunately, once you understand the rules of “matrix calculus” (as in our MIT short course), it’s pretty straightforward to manually differentiate functions like your $c(x) = \Vert A(x)^{-1} b \Vert$, where $A(x)$ constructs a sparse matrix from $x$. If you supply that, AD systems can then handle propagating the chain rule through any calculations that $c$ is composed with.
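To make that concrete, here is a sketch of the manual derivative, under assumptions not stated above: everything is real, $b$ does not depend on $x$, $A(x)$ is invertible, and the solution is nonzero. The names $u$ and $w$ are introduced just for this illustration. Write $u = A(x)^{-1} b$, so that $c = \Vert u \Vert = \sqrt{u^T u}$. Then

$$
dc = \frac{u^T \, du}{\Vert u \Vert}, \qquad du = d(A^{-1}) \, b = -A^{-1} \, dA \, A^{-1} b = -A^{-1} \, dA \, u ,
$$

and combining the two gives

$$
dc = -\frac{u^T}{\Vert u \Vert} A^{-1} \, dA \, u = -w^T \, dA \, u, \qquad \text{where } A^T w = \frac{u}{\Vert u \Vert} .
$$

So each component of the gradient is

$$
\frac{\partial c}{\partial x_k} = -w^T \frac{\partial A}{\partial x_k} u .
$$

The cost is one sparse solve for $u$ plus one "adjoint" solve with $A^T$ for $w$ (which can reuse the same sparse factorization), independent of the number of parameters $x_k$. Since $\partial A / \partial x_k$ is typically very sparse, the products $w^T (\partial A / \partial x_k) u$ are cheap. A custom rrule for $c$ would compute $u$ in the forward pass and return this gradient, scaled by the incoming cotangent, in the pullback.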
