# Speed comparison: matrix multiplication in Julia

You know JAX is not single-threaded, right?

``````
In [3]: for x in range(200):
   ...:     jnp.matmul(a, b)
   ...:
``````

Relative to what? Are you running BLAS single-threaded? You can check with

``````
using LinearAlgebra
BLAS.get_num_threads()
``````

and set it to one thread with

``````
BLAS.set_num_threads(1)
``````

As for orientation, there is a huge difference between row-major and column-major here:

``````
jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
170.893 ms (116 allocations: 30.52 MiB)
``````

vs

``````
jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
24.013 ms (116 allocations: 30.52 MiB)
``````
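A quick way to see why the second layout is so much faster (a minimal sketch; `a_bad` and `a_good` are illustrative names, not from the posts above): Julia is column-major, so the first index varies fastest in memory, and putting the batch index last makes each 2D slice one contiguous block.

``````
# Column-major layout: the first index is the fastest-moving one in memory.
a_bad  = reshape(Int32.(1:2*2000*400), 2, 2000, 400)
a_good = reshape(Int32.(1:2*2000*400), 2000, 400, 2)

strides(a_bad)   # (1, 2, 4000)      -- elements of a slice a_bad[i, :, :] are strided
strides(a_good)  # (1, 2000, 800000) -- each slice a_good[:, :, i] is contiguous
``````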
``````
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", False)

import os

a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1

print(a.dtype)  # int32

%%timeit
c = jnp.matmul(a, b)  # 291 ms
``````
``````
In [1]: import jax.numpy as jnp
   ...: from jax.config import config
   ...: config.update("jax_enable_x64", False)
   ...:
   ...: import os
   ...:
   ...: a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
   ...: b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1

In [3]: %timeit jnp.matmul(a, b)
1 s ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
``````
``````
julia> function mm(A, B)
           # Batched matmul: multiply the i-th pair of 2D slices.
           # Note the batch index is first, so each slice is non-contiguous.
           C = Array{Int32, 3}(undef, size(A, 1), size(A, 2), size(B, 3))
           for i in axes(A, 1)
               @views C[i, :, :] = A[i, :, :] * B[i, :, :]
           end
           C
       end
mm (generic function with 1 method)

julia> a = reshape(Int32.(1:2*2000*400), 2,2000,400);

julia> b = reshape(Int32.(1:2*2000*400), 2,400,2000);

julia> using BenchmarkTools

julia> import LinearAlgebra

julia> @btime mm($a, $b);
827.518 ms (18 allocations: 61.04 MiB)
``````

If you don’t correct for major dimension orientation, these comparisons won’t be very meaningful.
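If the arrays start out batch-first, one way to fix the orientation (a minimal sketch; `a_rowish` and `a_colish` are illustrative names) is `permutedims`, which copies the data into the new layout:

``````
# Move the batch dimension from first to last (this makes a rearranged copy).
a_rowish = reshape(Int32.(1:2*2000*400), 2, 2000, 400);
a_colish = permutedims(a_rowish, (2, 3, 1));
size(a_colish)  # (2000, 400, 2) -- batch index last, slices now contiguous
``````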

``````
➜ julia -q
julia> using BenchmarkTools, Tullio, LinearAlgebra

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
2.304 s (2 allocations: 30.52 MiB)
``````

You want `2000, 400, 2` as the dimensions. Also, Tullio doesn’t use BLAS.
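The same layout fix applies to the `mm`-style loop above. A minimal sketch under that assumption (the name `mm_colmajor` is mine, not from the thread): with the batch index last, each slice is a contiguous block, and `mul!` writes each product in place.

``````
using LinearAlgebra

function mm_colmajor(A, B)
    # Batch index is last, so A[:, :, i] and B[:, :, i] are contiguous blocks.
    C = Array{Int32, 3}(undef, size(A, 1), size(B, 2), size(A, 3))
    for i in axes(A, 3)
        @views mul!(C[:, :, i], A[:, :, i], B[:, :, i])
    end
    C
end

a = reshape(Int32.(1:2*2000*400), 2000, 400, 2);
b = reshape(Int32.(1:2*2000*400), 400, 2000, 2);
c = mm_colmajor(a, b);
``````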

``````
➜ jupyter console
Jupyter console 6.4.0

Python 3.7.10 (default, Apr 27 2021, 08:49:44)
IPython 7.25.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
   ...: from jax.config import config
   ...: config.update("jax_enable_x64", False)
   ...:
   ...: import os
   ...:
   ...: a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
   ...: b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
   ...:
   ...: print(a.dtype)  # int32
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
int32

In [2]: %%timeit
   ...: c = jnp.matmul(a, b)
   ...:
   ...:
395 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
``````

I am not sure why your JAX is slower.

This is getting confusing.

Changing the number of BLAS threads should change the performance of the `mm` function; it does not affect Tullio, if that is what you are comparing against.

In order to compare single-threaded performance, you should:

• Start Julia with a single thread
• Set BLAS threads to 1
• Set JAX to use one thread
• Then compare all three
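On the Julia side, that looks like the sketch below (for JAX, limiting CPU threads is commonly done via environment variables such as `XLA_FLAGS`, or by pinning the process with `taskset`; the exact flags are not verified here):

``````
# Start Julia with one thread: `julia --threads=1` (or JULIA_NUM_THREADS=1).
using LinearAlgebra

Threads.nthreads()        # Julia threads; Tullio parallelizes over these
BLAS.set_num_threads(1)   # BLAS threads; affects `mm`, not Tullio
``````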

For multi-threading, do the same, but set up each to have the same number of threads as you have physical cores.


And, very importantly: change the orientation of your arrays for Julia, since Julia is column-major.


It may be due to the CPU. Can you check `top` and run it in a loop (like I did) to make sure JAX is indeed using a single thread? The fact that your single-thread vs. multi-thread timing only changed by 5% is weird; on my end, I see a 1 s → 40 ms difference with 48 cores.

And you really should use the correct orientation (i.e., memory layout) to be fair.


I see your point. It seems like JAX is using multi-threading on my Mac. The same code and flags give the result you shared on Linux. This might be a bug in JAX on Mac.


Updated the post, take a look. Thanks for the help.

@sivakon I apologize for repeating myself, but why do you not fix the major-axis orientation issue? Isn’t it correct that JAX (like NumPy) uses row-major arrays? In that case these results are not comparable. Is there something I’m missing, since you’re not addressing this concern?

In the benchmark I posted further up, changing orientation leads to a 7x speedup on my computer.


So the latest result suggests that Julia is not slower than JAX even when we’re using the awkward memory layout (because of the column-major vs. row-major mismatch)? Good to know, because JAX is a NumPy rewrite with a TF-derived runtime and an LLVM backend that compiles kernels on its own.

Also, you really should fix column-major vs. row-major to make it apples to apples.

``````
using Tullio, BenchmarkTools
a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
# 2.212 s (2 allocations: 30.52 MiB)

using Tullio, BenchmarkTools
a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
# 1.854 s (2 allocations: 30.52 MiB)
``````

Is this what you mean? I don’t see any speed-up here.

That is strange. Here are my results:

BTW, did you use

``````
using Tullio, LoopVectorization
``````

?

``````
jl> Threads.nthreads()
1

jl> using Tullio, LoopVectorization

jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
713.923 ms (2 allocations: 30.52 MiB)

jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
133.608 ms (2 allocations: 30.52 MiB)
``````

``````
jl> Threads.nthreads()
8

jl> using Tullio, LoopVectorization

jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
154.384 ms (117 allocations: 30.52 MiB)

jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
23.660 ms (117 allocations: 30.52 MiB)
``````
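To summarize the four timings above:

| Julia threads | Batch-first layout (2×2000×400) | Batch-last layout (2000×400×2) |
| --- | --- | --- |
| 1 | 713.923 ms | 133.608 ms |
| 8 | 154.384 ms | 23.660 ms |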

Mystery solved. Using just Tullio without LoopVectorization yields times around 1.8 s. If I have

``````
using Tullio, LoopVectorization
``````

it gives me a 2.5x speedup with row-major orientation, but a 13.5x speedup for column-major.

But anyway, I was just getting a bit exasperated that numerous requests to fix the orientation issue were not acknowledged or noticed.


Can confirm this:

``````
julia> using Tullio, LoopVectorization, BenchmarkTools

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
1.415 s (2 allocations: 30.52 MiB)

julia> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

julia> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

julia> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
251.543 ms (2 allocations: 30.52 MiB)
``````

Without LoopVectorization it takes 3 to 4 seconds, with the latter case being only about 20% faster.


I wanted to fix my issue first and have an apples-to-apples comparison of the speed of Julia and JAX, and later optimize my Julia code.

``````
julia> using Tullio, LoopVectorization, BenchmarkTools

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
682.080 ms (2 allocations: 30.52 MiB)

julia> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));

julia> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));

julia> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
126.235 ms (2 allocations: 30.52 MiB)
``````

My code is hella fast now, thanks.