Sparse Jacobians of matrix models

You can see it from your script :slight_smile:
Just plug in 60×50 matrices and you will see that you get:

3000×3000 SparseArrays.SparseMatrixCSC{Float64, Int64} with 327000 stored entries:

so about 4 out of 100 entries are different from zero. The Zygote call takes many seconds, but the norm computation in your #Test is very fast, because ForwardDiff is much faster.
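
In case it helps to reproduce that count, here is a minimal sketch (assuming the m from your script, i.e. m(A) = A - sum(A, dims=2) * sum(A, dims=1)) that builds the dense ForwardDiff Jacobian and sparsifies it to count the nonzeros:

using ForwardDiff, SparseArrays

# the model applied to a 60×50 input, flattened so the Jacobian is 3000×3000
m(A) = A - sum(A, dims=2) * sum(A, dims=1)
A = rand(60, 50)
J = ForwardDiff.jacobian(x -> vec(m(reshape(x, size(A)))), vec(A))

Jsp = sparse(J)           # drop the exact zeros
nnz(Jsp) / length(Jsp)    # ≈ 0.036, i.e. roughly 4 nonzero entries per 100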

More important than the sparsity of the matrix is the sparsity pattern. It could be that the sparsity pattern in this case does not enable SparseDiffTools to be faster than ForwardDiff. This is the pattern I am seeing:

⣿⣿⣿⣿⣿⣾⣾⣾⣾⣮⡻⣿⣿⣷⣷⣷⣷⣷⣝⢿⣿⣿⣾⣾⣾⣾⣮⡻⣻⣿⣷⣷⣷⣷⣷⣕⣝⢿⣿⣿⣾⣾⣾⣾⣮⡻⣿⣿⣷⣷⣷⣷⣷⠄
⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣾⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡫⣻⣿⣿⣿⣿⣿⣷⣵⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⠅
⣻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡫⣿⣿⣿⣿⣿⣿⣷⣵⢝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⡅
⣺⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡪⡿⣿⣿⣿⣿⣿⣷⣷⢝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⡇
⡺⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡺⡿⣿⣿⣿⣿⣿⣿⣗⢝⢿⣿⣿⣿⣿⣿⣿⣮⡃
⣿⣮⡺⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡺⡻⣿⣿⣿⣿⣿⣿⣗⣝⢿⣿⣿⣿⣿⣿⡂
⢿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣻⣿⣿⣿⣿⣿⣿⣕⣝⢿⣿⣿⣿⡂
⢽⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡫⣻⣿⣿⣿⣿⣿⣷⣵⣝⢿⣿⡇
⢽⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡪⣿⣿⣿⣿⣿⣿⣷⣵⢝⠇
⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡪⡿⣿⣿⣿⣿⣿⣷⠅
⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡺⡿⣿⣿⣿⣿⠅
⣺⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣵⣿⣿⣿⣿⣿⣿⣿⣯⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⡻⣿⣿⡅
⣺⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣻⡇
⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡃
⣿⣾⡮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⡂
⢽⣿⣿⣾⡮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡫⣿⣿⣿⣿⣿⣿⣿⣟⣽⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⡂
⢽⣿⣿⣿⣿⣿⡪⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⡇
⢝⢿⣿⣿⣿⣿⣿⣯⣪⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⠇
⣷⣝⢝⣿⣿⣿⣿⣿⣿⣯⣪⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⠅
⣿⣿⣷⣝⢝⣿⣿⣿⣿⣿⣿⣮⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⠅
⣺⣿⣿⣿⣷⣕⢽⣿⣿⣿⣿⣿⣿⣾⡮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⡇
⣺⣿⣿⣿⣿⣿⣷⣕⢿⢿⣿⣿⣿⣿⣿⣾⡪⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣪⡻⡇
⣮⡻⣿⣿⣿⣿⣿⣿⣷⣕⢿⢿⣿⣿⣿⣿⣿⣿⡪⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡂
⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢟⢿⣿⣿⣿⣿⣿⣯⣪⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡂
⢽⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢝⣿⣿⣿⣿⣿⣿⣯⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡆
⢽⣿⣿⣿⣿⣿⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢝⣿⣿⣿⣿⣿⣿⣮⣮⡻⣿⣿⣿⣿⣿⣿⣷⣝⢿⣿⣿⣿⣿⣿⣿⣮⡻⣻⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇
⠙⠟⠟⠟⠟⠿⠿⠿⠮⠻⠻⠻⠻⠻⠿⠿⠷⠕⠝⠟⠟⠟⠟⠿⠿⠾⠮⠻⠻⠻⠻⠻⠿⠿⠷⠝⠟⠟⠟⠟⠿⠿⠿⠮⠻⠻⠻⠻⠻⠿⠿⠿⠿⠇

You might already be familiar with the theory, but here is a very brief intro: Sparse Jacobian or Hessian · Nonconvex.jl. There is also a lecture by Chris Rackauckas on YouTube about the same topic.


“to speed up an optimization problem”

What’s the full formulation you’re trying to achieve?

Did you consider using JuMP?

@mohamed82008 that’s a plotting artifact. Here is the sparsity pattern:

The colour vector sparse_m.flat_f.jac_colors shows that the SparseDiffTools colouring algorithm could not find any better splitting than giving each variable its own colour: sparse_m.flat_f.jac_colors == 1:3000 is true. This is the worst-case scenario for SparseDiffTools.
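
For anyone following along, here is a small sketch of how one can check that conclusion directly with SparseDiffTools, assuming Jsp is the 3000×3000 sparse Jacobian pattern from earlier in the thread:

using SparseDiffTools

colors = matrix_colors(Jsp)       # greedy distance-1 colouring of the columns
maximum(colors)                   # 3000 here: every column needs its own colour
maximum(colors) == size(Jsp, 2)   # true, so sparsity buys nothing for AD in this case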

This info may not be that relevant, and it might be common knowledge among the members here, but I’d like to point out that if you are targeting speed with sparse matrices, do take a look at renumbering methods before doing solves. Something like running METIS’s nested dissection or SuiteSparse’s COLAMD/SYMAMD can go a long way. (Maybe this is handled internally by the aforementioned packages; if not, I would argue that renumbering methods should be incorporated.) This may also only be applicable to solves, i.e. reducing fill-in; for reducing bandwidth/profile, Cuthill-McKee is used.

I don’t know whether computing the Jacobian itself would be faster for renumbered matrices; just thought I’d chime in with something that could be helpful if the structure of the sparsity matters in the aforementioned tools.
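
For reference, here is a rough sketch of what such a renumbering looks like in Julia, assuming Metis.jl is installed (Metis.permutation computes a nested-dissection fill-reducing ordering). Note this targets fill-in during factorization, which is separate from the colouring used for sparse differentiation:

using Metis, SparseArrays

S = sprand(1000, 1000, 0.01)
A = S + S'                           # Metis expects a symmetric sparsity pattern

perm, iperm = Metis.permutation(A)   # fill-reducing nested-dissection ordering
Aperm = A[perm, perm]                # renumbered matrix; factorizations typically see less fill-in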


Thank you @odow , I didn’t consider JuMP yet because I prefer reconstructing the optimization routine from scratch if I can. I tend to use Julia more to solve the “my first high-level language is too slow” problem than the “two-language problem”. I might try it in the future; the timescale of my quick experiment was too short to dig deep into something like JuMP (I might be wrong).

Thank you @mohamed82008 . So are we saying that we should not expect to extract much extra speed-up from sparsity in this case, or is it still worth double-checking whether one could run

forwarddiff_color_jacobian!(autoJac, s, dv, colorvec = colors)

without errors?

I call that function internally to compute the Jacobian, so it’s not needed on its own unless you suspect the correctness of the implementation, which you are more than welcome to verify. I would be more inclined to try different colouring algorithms from SparseDiffTools, besides the column renumbering or permutation described by @acxz (that should probably be part of the colour-finding algorithm in SparseDiffTools, if it’s not already there). Feel free to explore the SparseDiffTools implementation of colouring algorithms and possibly even improve it. The goal is to have as few colours as possible.
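
In case it is useful, this is roughly what trying another colouring algorithm looks like through matrix_colors. A tiny tridiagonal pattern is used only to illustrate the API, and the algorithm type names (e.g. GreedyStar2Color) are an assumption about what the installed SparseDiffTools version exports:

using SparseDiffTools, SparseArrays, LinearAlgebra

Jsmall = sparse(Tridiagonal(ones(9), ones(10), ones(9)))   # small banded pattern

maximum(matrix_colors(Jsmall))                       # default greedy distance-1 colouring → 3 colours
maximum(matrix_colors(Jsmall, GreedyStar2Color()))   # an alternative colouring algorithm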

On a shorter timescale, any suggestions on matching JAX performance on the same Jacobian computation?

As a reference point these are the results with JAX (CPU and GPU):

CPU

In [1]: import jax.numpy as jnp
   ...: from jax import grad, jit, vmap
   ...: from jax import random
   ...: from jax import jacfwd, jacrev
   ...:
   ...: key = random.PRNGKey(0)
   ...: a = random.normal(key,(60,60))
   ...: b = jnp.ones_like(a)
   ...: def m(a):
   ...:   return a - jnp.sum(a, 1, keepdims=True) @ jnp.sum(a, 0, keepdims=True)
   ...: jb = jit(jacrev(m))
   ...: jm = jit(jacfwd(m))
   ...: %timeit jb(a)
   ...: %timeit jm(a)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
15.7 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
36 ms ± 449 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

And GPU

In [3]: import jax.numpy as jnp
   ...: from jax import grad, jit, vmap
   ...: from jax import random
   ...: from jax import jacfwd, jacrev
   ...:
   ...: key = random.PRNGKey(0)
   ...: a = random.normal(key,(60,60))
   ...: b = jnp.ones_like(a)
   ...: def m(a):
   ...:   return a - jnp.sum(a, 1, keepdims=True) @ jnp.sum(a, 0, keepdims=True)
   ...: jb = jit(jacrev(m))
   ...: jm = jit(jacfwd(m))
   ...: %timeit jb(a)
   ...: %timeit jm(a)
287 µs ± 781 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
566 µs ± 47.3 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

ForwardDiff is around 400 ms on my machine, still roughly 20 times slower than the JAX CPU version. It might be that JAX is using multi-threading automatically, but it would be nice to get close.

With 32-bit floats I remember it being less than 400 ms, but still in the 200 ms range.
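
For reference, a minimal sketch of how the ForwardDiff timings above can be reproduced, assuming BenchmarkTools and the 60×60 case from the JAX snippet:

using ForwardDiff, BenchmarkTools

m(A) = A - sum(A, dims=2) * sum(A, dims=1)
A = rand(60, 60)

@btime ForwardDiff.jacobian($m, $A);               # Float64 case, ~400 ms reported above
@btime ForwardDiff.jacobian($m, $(Float32.(A)));   # Float32 case, reported above as roughly 200 ms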

PS Feel free to split the performance comparison with JAX into another thread if appropriate.

Yes, so @gianmariomanca this doesn’t have a sparsity pattern that sparse automatic differentiation can reduce. Sparse symbolic differentiation would be able to handle it well, so if you want to try using Symbolics.jl to generate it, that would be as fast as all hell.
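
Not necessarily what Chris has in mind, but as a rough sketch of the direct Symbolics.jl route (kept at a small size, since building the symbolic Jacobian becomes expensive as the matrix grows):

using Symbolics, SparseArrays

m(A) = A - sum(A, dims=2) * sum(A, dims=1)

n = 10
@variables A[1:n, 1:n]
Asym = collect(A)                                          # Matrix{Num} of symbolic entries

Jsym = Symbolics.sparsejacobian(vec(m(Asym)), vec(Asym))   # symbolic sparse Jacobian
jac_oop, jac_ip = build_function(Jsym, vec(Asym), expression = Val(false))

x = rand(n, n)
J = jac_oop(vec(x))                                        # numeric sparse Jacobian at x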

It looks like the SparseDiffTools.jl README is just old. I’ll get that updated.

Using a global like that will slow code down by more than 20x. Don’t use a global if you don’t need it. Also, for code that mutates like that, Enzyme will be a much faster AD than Zygote (I’m surprised Zygote even worked?).

Note that this only affects the cost of constructing the sparse Jacobian, and it only affects it if maximum(colors) is sufficiently smaller than length(s). As @mohamed82008 mentioned, maximum(colors) == length(s) in your case, so no sparsity simplification occurs here.

But more directly, @gianmariomanca you should post a profile of your code. I recommend using

to share an interactive flamegraph. The real question is: what is taking all of the time during the Jacobian construction?


Thank you @ChrisRackauckas , I’ll definitely check the Symbolics approach at some point in the future, I’ve heard good things :slight_smile: Even though it might not be trivial in this particular case.

Just to avoid misunderstanding: I agree about globals, but for clarity, I didn’t benchmark or use the version of the function that depends on a global variable, just the basic function:

m(A) = A - sum(A, dims=2) * sum(A, dims=1)

I’ll try to share the profiling results, but I’m not sure how, since attachments are not allowed and one needs to share the HTML or some other interactive format for the report to be useful.
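
In the meantime, here is a minimal sketch of collecting a profile with the Profile standard library; turning it into an interactive flamegraph is then a separate step (e.g. with ProfileView.jl or StatProfilerHTML.jl):

using Profile, ForwardDiff

m(A) = A - sum(A, dims=2) * sum(A, dims=1)
A = rand(60, 60)

ForwardDiff.jacobian(m, A)        # run once first so compilation is excluded
Profile.clear()
@profile ForwardDiff.jacobian(m, A)
Profile.print(format = :flat)     # text summary; flamegraph viewers consume the same data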

Try

using NonconvexUtils, LinearAlgebra

# Preprocess

m(A) = A - sum(A, dims=2) * sum(A, dims=1)
x = rand(5, 5)
sym_m = symbolify(m, x, sparse = true)

# Compute the jacobian

x = rand(5, 5)
J = sym_m.flat_f.g(vec(x))

which uses Symbolics under the hood.

julia> sym_m = symbolify(m, x, sparse = true)
ERROR: UndefVarError: symbolify not defined

Make sure you are on the latest NonconvexUtils and that no other loaded package exports symbolify. If two packages export the same name, you need to qualify the name with the package name, e.g. NonconvexUtils.symbolify, to tell Julia which symbolify you are talking about.


Thank you. It does run in the REPL with the latest packages. It seems to be doing fine for moderate sizes (~15×15), but I could not really run anything above ~20×20, which is quite far from ~60×60.
