Speed comparison matrix multiplication in Julia

You know JAX is not single-threaded, right?

In [3]: for x in range(200):
   ...:     jnp.matmul(a,b)


Relative to what? Are you running BLAS single-threaded? You can check with

julia> LinearAlgebra.BLAS.get_num_threads()

and set it to one thread with

julia> LinearAlgebra.BLAS.set_num_threads(1)

As for orientation, there is a huge difference between row and column major here:

jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
  170.893 ms (116 allocations: 30.52 MiB)


jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
  24.013 ms (116 allocations: 30.52 MiB)
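To see why the layout matters, here is a minimal sketch in plain Python (standard library only; the helper name col_major_strides is made up for illustration) of the element strides that Julia's column-major storage gives each axis. With size (2, 2000, 400), neighbouring elements along the second axis sit 2 apart, so each matrix slice a[i, :, :] is strided; with size (2000, 400, 2), each slice a[:, :, i] is one contiguous 2000×400 block.

```python
def col_major_strides(shape):
    """Element strides of a column-major (Julia/Fortran) array.

    The first axis is contiguous (stride 1); each later axis jumps over
    the product of all earlier dimension sizes.
    """
    strides = []
    step = 1
    for n in shape:
        strides.append(step)
        step *= n
    return tuple(strides)


# Layout of the first benchmark: batch index first.
print(col_major_strides((2, 2000, 400)))   # (1, 2, 4000)
# Layout of the second benchmark: batch index last.
print(col_major_strides((2000, 400, 2)))   # (1, 2000, 800000)
```

In the first case the loops over j and q step through memory 2 and 4000 elements at a time; in the second, walking j is unit-stride, which is the access pattern vectorized loops benefit from.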

import os
# XLA reads XLA_FLAGS when it initializes its CPU backend, so set it before importing JAX
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")

import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", False)

a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1

print(a.dtype) # int32

c = jnp.matmul(a,b) # 291ms
In [1]: import jax.numpy as jnp
   ...: from jax.config import config
   ...: config.update("jax_enable_x64", False)
   ...: import os
   ...: os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")
   ...: a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
   ...: b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1

In [3]: %timeit jnp.matmul(a,b)
1 s ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
julia> function mm(A,B)
           C = Array{Int32, 3}(undef, size(A,1),size(A,2), size(B,3))
           for i in axes(A,1)
               @views C[i,:,:] = A[i,:,:] * B[i,:,:]
           end
           return C
       end
mm (generic function with 1 method)

julia> a = reshape(Int32.(1:2*2000*400), 2,2000,400);

julia> b = reshape(Int32.(1:2*2000*400), 2,400,2000);

julia> using BenchmarkTools

julia> import LinearAlgebra

julia> LinearAlgebra.BLAS.set_num_threads(1)

julia> @btime mm($a,$b);
  827.518 ms (18 allocations: 61.04 MiB)

Julia is already faster.

If you don’t correct for major dimension orientation, these comparisons won’t be very meaningful.

➜ julia -q
julia> using BenchmarkTools, Tullio, LinearAlgebra

julia> BLAS.set_num_threads(1)

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
  2.304 s (2 allocations: 30.52 MiB)

You want the size (2000, 400, 2). Also, Tullio doesn’t use BLAS.
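For reference, @tullio c[j, k, i] := a[j, q, i] * b[q, k, i] is a batched matrix product: the repeated index q is summed over and i labels the batch. The sketch below (plain Python on nested lists; batched_matmul is a hypothetical name, not part of Tullio) spells out those semantics, keeping j as the innermost loop, the order a column-major implementation would favour. Separately, Julia only hands Float32/Float64 (and complex) matrix products to BLAS, so Int32 products like the ones in mm go through a generic fallback and BLAS thread settings should not affect them either.

```python
def batched_matmul(a, b):
    """c[j][k][i] = sum over q of a[j][q][i] * b[q][k][i]: one matmul per batch index i.

    a has shape (J, Q, I) and b has shape (Q, K, I) as nested lists, mirroring the
    Julia layout (2000, 400, 2) where the batch index comes last.
    """
    J, Q, I = len(a), len(a[0]), len(a[0][0])
    K = len(b[0])
    assert len(b) == Q and len(b[0][0]) == I, "inner and batch sizes must match"
    c = [[[0] * I for _ in range(K)] for _ in range(J)]
    # Loop order i, k, q, j keeps j innermost, so a column-major implementation
    # would walk a[:, q, i] and c[:, k, i] contiguously.
    for i in range(I):
        for k in range(K):
            for q in range(Q):
                bqk = b[q][k][i]
                for j in range(J):
                    c[j][k][i] += a[j][q][i] * bqk
    return c


# One batch holding [[1, 2], [3, 4]] @ [[5, 6], [7, 8]].
a = [[[1], [2]], [[3], [4]]]
b = [[[5], [6]], [[7], [8]]]
print(batched_matmul(a, b))  # [[[19], [22]], [[43], [50]]]
```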

➜ jupyter console
Jupyter console 6.4.0

Python 3.7.10 (default, Apr 27 2021, 08:49:44)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.25.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
   ...: from jax.config import config
   ...: config.update("jax_enable_x64", False)
   ...: import os
   ...: os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")
   ...: a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
   ...: b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
   ...: print(a.dtype) # int32
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [2]: %%timeit
   ...: c = jnp.matmul(a,b)
395 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

I am not sure why your JAX is slower.

This is getting confusing.

Changing the number of BLAS threads should change the performance of the mm function; if you are comparing that to Tullio, note that it does not affect Tullio.

In order to compare single-threaded performance, you should

  • Start Julia with a single thread
  • Set BLAS threads to 1
  • Set JAX to use one thread
  • Then compare all three

For multi-threading, do the same, but set up each to have the same number of threads as you have physical cores.
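A minimal sketch of pinning thread counts for the single-threaded runs, assuming the benchmarks are launched from a Python driver script. JULIA_NUM_THREADS and OPENBLAS_NUM_THREADS are the standard variables Julia and OpenBLAS read at startup; the XLA_FLAGS string is copied from the JAX snippets in this thread; single_thread_env is a hypothetical helper. Calling BLAS.set_num_threads(1) inside Julia remains the most direct way to fix the BLAS count.

```python
import os

# XLA_FLAGS string copied from the JAX snippets in this thread.
XLA_SINGLE_THREAD = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0"


def single_thread_env(base=None):
    """Return a copy of `base` (default: os.environ) with single-thread settings applied.

    Hypothetical helper; the environment passed in is left untouched.
    """
    env = dict(os.environ if base is None else base)
    env["JULIA_NUM_THREADS"] = "1"         # Julia's own threads (like `julia -t 1`)
    env["OPENBLAS_NUM_THREADS"] = "1"      # OpenBLAS worker threads
    env["XLA_FLAGS"] = XLA_SINGLE_THREAD   # JAX/XLA CPU threading
    return env
```

Usage would be something like subprocess.run(["julia", "bench.jl"], env=single_thread_env()), with bench.jl standing in for your benchmark script. For the multi-threaded comparison, set the Julia and OpenBLAS counts to the number of physical cores instead.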


And, very importantly: change the orientation of your arrays for Julia, since Julia is column-major.

It may be due to the CPU. Can you check top while running it in a loop (like I did) to make sure JAX is indeed using a single thread? The fact that your single-threaded vs. multi-threaded time changed by only 5% is weird; on my end I see 1 s → 40 ms with 48 cores.
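The same check can be done from inside Python, as a rough stdlib-only sketch (cpu_to_wall_ratio is a made-up name): time.process_time() sums CPU time over all threads of the process, so dividing it by wall-clock time estimates how many cores were busy on average.

```python
import time


def cpu_to_wall_ratio(fn, repeats=5):
    """Run fn() `repeats` times; return process CPU time divided by wall-clock time.

    A ratio near 1.0 suggests roughly one busy core; a ratio well above 1.0
    means several threads were working at once.
    """
    cpu0, wall0 = time.process_time(), time.perf_counter()
    for _ in range(repeats):
        fn()
    cpu1, wall1 = time.process_time(), time.perf_counter()
    return (cpu1 - cpu0) / (wall1 - wall0)


# Pure-Python work runs on one core, so this should print a value close to 1.
print(round(cpu_to_wall_ratio(lambda: sum(i * i for i in range(300_000)), repeats=3), 2))
```

With JAX you would time something like lambda: jnp.matmul(a, b).block_until_ready(); the block_until_ready() call matters because JAX dispatches work asynchronously, so without it the timer can stop before the product has been computed.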

And you really should use the correct orientation (i.e. memory layout) to be fair.

I see your point. It seems JAX is using multiple threads on my Mac; the same code and flags give the result you shared on Linux. This might be a bug in JAX on Mac.

Updated the post, take a look. Thanks for the help!

@sivakon I apologize for repeating myself, but why do you not fix the major-axis orientation issue? Isn’t it correct that JAX (like NumPy) uses row-major arrays? In that case these results are not comparable. Is there something I’m missing, since you’re not addressing this concern?

In the benchmark I posted further up, changing orientation leads to a 7x speedup on my computer.
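Concretely, NumPy and JAX present arrays in row-major (C) order by default, with the last index varying fastest, while Julia is column-major, with the first index varying fastest. A minimal sketch with made-up helper names and 0-based indices: element (i, j, k) of a row-major (2, 2000, 400) array sits at the same offset as element (k, j, i) of a column-major (400, 2000, 2) array, so reversing the dimension order gives Julia the same memory layout the JAX code sees.

```python
from itertools import product


def row_major_offset(index, shape):
    """0-based linear offset in a row-major (C order) array: last index varies fastest."""
    offset = 0
    for i, n in zip(index, shape):
        offset = offset * n + i
    return offset


def col_major_offset(index, shape):
    """0-based linear offset in a column-major (Julia/Fortran) array: first index varies fastest."""
    return row_major_offset(tuple(reversed(index)), tuple(reversed(shape)))


# Row-major (2, 2000, 400) as in the JAX code vs. column-major (400, 2000, 2) in Julia:
# with the indices reversed, every sampled element lands at the same offset.
shape = (2, 2000, 400)
rev = tuple(reversed(shape))
same = all(
    row_major_offset(idx, shape) == col_major_offset(tuple(reversed(idx)), rev)
    for idx in product(range(2), range(0, 2000, 499), range(0, 400, 133))
)
print(same)  # True
```

Reversing all axes also transposes each matrix; the size (2000, 400, 2) suggested earlier keeps the matrices untransposed while still putting the batch index last.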

So the latest result suggests that Julia is not slower than JAX even with the awkward memory layout (column-major vs. row-major)? Good to know, because JAX is a NumPy rewrite on a TensorFlow-derived runtime (XLA) with an LLVM backend that compiles its own kernels.

Also, you really should fix the column-major vs. row-major layout to make it apples to apples.

using Tullio, BenchmarkTools
a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k]; 
# 2.212 s (2 allocations: 30.52 MiB)

using Tullio, BenchmarkTools
a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i]; 
# 1.854 s (2 allocations: 30.52 MiB)

Is this what you mean? I don’t see any speed-up here.

That is strange. Here are my results:

BTW, did you use

using Tullio, LoopVectorization



jl> Threads.nthreads()

jl> using Tullio, LoopVectorization

jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
  713.923 ms (2 allocations: 30.52 MiB)

jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
  133.608 ms (2 allocations: 30.52 MiB)

8 threads:

jl> Threads.nthreads()

jl> using Tullio, LoopVectorization

jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
  154.384 ms (117 allocations: 30.52 MiB)

jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
  23.660 ms (117 allocations: 30.52 MiB)

Mystery solved. Using just Tullio without LoopVectorization yields times around 1.8 s. If I have

using Tullio, LoopVectorization

it gives me a 2.5x speedup with the row-major orientation, but a 13.5x speedup for column-major.
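The two speed-up figures follow from the first set of timings posted above (713.923 ms and 133.608 ms with LoopVectorization) against roughly 1.8 s for plain Tullio. A small arithmetic check in Python, using only those reported numbers:

```python
def speedup(before_s, after_s):
    """How many times faster `after_s` is than `before_s` (both in seconds)."""
    return before_s / after_s


plain_tullio = 1.8  # approximate time without LoopVectorization, as reported
with_loopvec = {
    "(2, 2000, 400) layout": 0.713923,
    "(2000, 400, 2) layout": 0.133608,
}
for layout, t in with_loopvec.items():
    print(f"{layout}: {speedup(plain_tullio, t):.1f}x")
# (2, 2000, 400) layout: 2.5x
# (2000, 400, 2) layout: 13.5x
```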

But anyway, I was just getting a bit exasperated that numerous requests to fix the orientation issue were not acknowledged or noticed.


Can confirm this:

julia> using Tullio, LoopVectorization, BenchmarkTools

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
  1.415 s (2 allocations: 30.52 MiB)

julia> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

julia> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

julia> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
  251.543 ms (2 allocations: 30.52 MiB)

Without LoopVectorization it takes 3-4 seconds, with the latter case only about 20% faster.

I wanted to fix my issue first, get an apples-to-apples speed comparison between Julia and JAX, and optimize my Julia code later.

julia> using Tullio, LoopVectorization, BenchmarkTools

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
  682.080 ms (2 allocations: 30.52 MiB)

julia> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

julia> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

julia> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
  126.235 ms (2 allocations: 30.52 MiB)

My code is hella fast now, thanks.