You know Jax is not single-threaded, right?
In [3]: for x in range(200):
   ...:     jnp.matmul(a, b)
   ...:
Relative to what? Are you running BLAS single-threaded? You can check with
BLAS.get_num_threads()
and set it to one thread with
BLAS.set_num_threads(1)
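A minimal sketch of that check, assuming Julia ≥ 1.6 (where both functions are available under LinearAlgebra.BLAS):

using LinearAlgebra
BLAS.get_num_threads()   # how many threads BLAS is currently using
BLAS.set_num_threads(1)  # pin BLAS to one thread for a fair single-threaded comparison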
As for orientation, there is a huge difference between row and column major here:
jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
170.893 ms (116 allocations: 30.52 MiB)
vs
jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
24.013 ms (116 allocations: 30.52 MiB)
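The underlying reason is that Julia is column-major: the first index moves fastest in memory, so with the batch axis first each 2000×400 slice is strided rather than contiguous. A quick illustration using Base's strides function:

a = reshape(Int32.(1:2*2000*400), 2, 2000, 400);
strides(a)   # (1, 2, 4000): elements within a[i, :, :] sit 2 apart in memory

a2 = reshape(Int32.(1:2*2000*400), 2000, 400, 2);
strides(a2)  # (1, 2000, 800000): each a2[:, :, i] slice is contiguous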
import os
# XLA reads XLA_FLAGS when the backend initializes, so it is safest to set it before importing jax
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")

import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", False)

a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
print(a.dtype)  # int32

%%timeit
c = jnp.matmul(a, b)  # 291 ms
In [1]: import jax.numpy as jnp
...: from jax.config import config
...: config.update("jax_enable_x64", False)
...:
...: import os
...: os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")
...:
...: a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
...: b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
In [3]: %timeit jnp.matmul(a,b)
1 s ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
julia> function mm(A, B)
           # batched matmul over the first (batch) dimension
           C = Array{Int32, 3}(undef, size(A, 1), size(A, 2), size(B, 3))
           for i in axes(A, 1)
               # with the batch axis first, these slices are non-contiguous
               @views C[i, :, :] = A[i, :, :] * B[i, :, :]
           end
           C
       end
mm (generic function with 1 method)
julia> a = reshape(Int32.(1:2*2000*400), 2,2000,400);
julia> b = reshape(Int32.(1:2*2000*400), 2,400,2000);
julia> using BenchmarkTools
julia> import LinearAlgebra
julia> LinearAlgebra.BLAS.set_num_threads(1)
julia> @btime mm($a,$b);
827.518 ms (18 allocations: 61.04 MiB)
Julia is already faster.
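For what it's worth, here is a sketch of the same loop with the batch axis moved last, so each slice is a contiguous column-major matrix (mm_colmajor is just my name for it, not something from this thread):

using LinearAlgebra  # for mul!

function mm_colmajor(A, B)
    # batch axis last: A[:, :, i] and B[:, :, i] are contiguous in memory
    C = Array{Int32, 3}(undef, size(A, 1), size(B, 2), size(A, 3))
    for i in axes(A, 3)
        @views mul!(C[:, :, i], A[:, :, i], B[:, :, i])  # write each product in place
    end
    C
end

a = reshape(Int32.(1:2*2000*400), 2000, 400, 2);
b = reshape(Int32.(1:2*2000*400), 400, 2000, 2);
c = mm_colmajor(a, b);  # 2000×2000×2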
If you don’t correct for major dimension orientation, these comparisons won’t be very meaningful.
➜ ~ julia -q
julia> using BenchmarkTools, Tullio, LinearAlgebra
julia> BLAS.set_num_threads(1)
julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
2.304 s (2 allocations: 30.52 MiB)
You want: 2000, 400, 2.
Also, Tullio doesn't use BLAS.
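So the BLAS thread setting doesn't touch the @tullio timing above; what does is Julia's own thread count. A quick sketch of which knob controls what:

using LinearAlgebra
Threads.nthreads()       # Julia threads: what @tullio parallelizes over
                         # (set with `julia -t N` or JULIA_NUM_THREADS)
BLAS.set_num_threads(1)  # only affects BLAS-backed ops like dense `*`, not @tullio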
➜ ~ jupyter console
Jupyter console 6.4.0
Python 3.7.10 (default, Apr 27 2021, 08:49:44)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.25.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax.numpy as jnp
...: from jax.config import config
...: config.update("jax_enable_x64", False)
...:
...: import os
...: os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")
...:
...: a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
...: b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
...:
...: print(a.dtype) # int32
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
int32
In [2]: %%timeit
...: c = jnp.matmul(a,b)
...:
...:
395 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I am not sure why your Jax is slower
This is getting confusing.
Changing the number of BLAS threads should change the performance of the mm function; if you are comparing that to Tullio, note that it does not affect Tullio.
In order to compare single-threaded performance, you should run each side with one thread: BLAS.set_num_threads(1) plus a single Julia thread on the Julia side, and the single-thread XLA flags on the Jax side.
For multi-threading, do the same, but set up each to have the same number of threads as you have physical cores.
And, very importantly: change the orientation of your arrays for Julia, since Julia is column-major.
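Concretely, one way to re-lay-out the same data is permutedims (a sketch; the benchmarks in this thread simply reshape a fresh range into the transposed shape instead):

a_row = reshape(Int32.(1:2*2000*400), 2, 2000, 400);  # batch axis first, matching the Jax layout
a_col = permutedims(a_row, (2, 3, 1));                # copy with the batch axis last
size(a_col)  # (2000, 400, 2)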
It may be due to the CPU. Can you check top, and run it in a loop (like what I did), to make sure Jax is indeed using a single thread? The fact that your single-thread vs. multi-thread timing only changed by 5% is weird; on my end, I see a 1 s → 40 ms difference with 48 cores.
And you really should use the correct orientation (i.e. memory layout) to be fair.
I see your point. Seems like Jax is using multi-threading on my Mac. Same code and flags give the result you shared on Linux. This might be a bug in Jax on Mac.
Updated the post, take a look. Thanks for the help
@sivakon I apologize for repeating myself, but why do you not fix the major-axis orientation issue? Isn't it correct that Jax (like numpy) uses row-major arrays? In that case these results are not comparable. Is there something I'm missing, since you're not addressing this concern?
In the benchmark I posted further up, changing orientation leads to a 7x speedup on my computer.
So the latest result suggests that Julia is not slower than Jax even when we're using the awkward memory layout (col-major vs. row-major)? Good to know, since Jax is a Numpy rewrite with a TF runtime and an LLVM backend that compiles kernels on its own.
Also, you really should fix col-major vs. row-major to make it apples to apples.
using Tullio, BenchmarkTools
a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
# 2.212 s (2 allocations: 30.52 MiB)
using Tullio, BenchmarkTools
a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
# 1.854 s (2 allocations: 30.52 MiB)
Is this what you mean? I don’t see any speed-up here.
That is strange. Here are my results:
BTW, did you use
using Tullio, LoopVectorization
?
Single-threaded:
jl> Threads.nthreads()
1
jl> using Tullio, LoopVectorization
jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
713.923 ms (2 allocations: 30.52 MiB)
jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
133.608 ms (2 allocations: 30.52 MiB)
8 threads:
jl> Threads.nthreads()
8
jl> using Tullio, LoopVectorization
jl> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
jl> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
jl> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
154.384 ms (117 allocations: 30.52 MiB)
jl> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
jl> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
jl> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
23.660 ms (117 allocations: 30.52 MiB)
Mystery solved. Using just Tullio without LoopVectorization yields times around 1.8 s. If I have
using Tullio, LoopVectorization
it gives me a 2.5x speedup with row-major orientation, but a 13.5x speedup for column-major.
But anyway, I was just getting a bit exasperated that numerous requests to fix the orientation issue were not acknowledged or noticed.
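For the record, my understanding is that @tullio decides at macro-expansion time whether LoopVectorization is available, so the using statement has to happen before the @tullio expression is compiled. A minimal sketch of the difference:

v = rand(Float32, 1_000_000);

using Tullio
@tullio c1[i] := v[i] * v[i]   # expands to plain @inbounds loops

using LoopVectorization
@tullio c2[i] := v[i] * v[i]   # same expression, now expanded with vectorized kernels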
Can confirm this:
julia> using Tullio, LoopVectorization, BenchmarkTools
julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
1.415 s (2 allocations: 30.52 MiB)
julia> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
julia> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
julia> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
251.543 ms (2 allocations: 30.52 MiB)
Without LoopVectorization it takes 3-4 seconds, with the latter case being only about 20% faster.
I wanted to fix my issue first and get an apples-to-apples comparison between Julia and Jax, and later optimize my Julia code.
julia> using Tullio, LoopVectorization, BenchmarkTools
julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
julia> @btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];
682.080 ms (2 allocations: 30.52 MiB)
julia> a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
julia> b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
julia> @btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i];
126.235 ms (2 allocations: 30.52 MiB)
My code is hella fast now, thanks.