Optimized version in Julia.
(126ms optimized Julia vs. 676ms in JAX - Int32)
(521ms optimized Julia vs. 2.2s in JAX - Int64)
Using Tullio - Int32
using Tullio, BenchmarkTools, LoopVectorization
# Batch dimension first (mimics JAX's row-major layout)
a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k]; # 682ms
using Tullio, BenchmarkTools, LoopVectorization
# Batch dimension last (column-major friendly: each slice is contiguous)
a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i]; # 126ms
Using Tullio - Int64
using Tullio, BenchmarkTools, LoopVectorization
# Batch dimension first (mimics JAX's row-major layout)
a = Array(reshape(Int.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k]; # 1.34s
using Tullio, BenchmarkTools, LoopVectorization
# Batch dimension last (column-major friendly: each slice is contiguous)
a = Array(reshape(Int.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i]; # 521ms
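Why the layout matters: Julia arrays are column-major, so the first index varies fastest in memory. Putting the batch index last makes each 2000×400 slice one contiguous block, which lets Tullio/LoopVectorization run stride-1 SIMD loops. A minimal check (the names a_row/a_col are just for illustration):
a_row = reshape(Int32.(1:2*2000*400), 2, 2000, 400);  # batch index first
a_col = reshape(Int32.(1:2*2000*400), 2000, 400, 2);  # batch index last
strides(a_row)  # (1, 2, 4000): a slice a_row[i, :, :] skips every other element
strides(a_col)  # (1, 2000, 800000): a_col[:, :, i] is one contiguous block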
Update:
After discussions, I found out that JAX uses multi-threading by default and does its calculations in Int32, while Julia defaults to a single thread and Int64. So the original comparison was flawed.
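For reference, here is a quick way to confirm the Julia-side defaults the update is talking about (values assume a stock 64-bit session started without the -t flag):
using LinearAlgebra
Threads.nthreads()      # 1 unless Julia was started with `julia -t N`
BLAS.get_num_threads()  # BLAS threading is configured separately
typeof(1)               # Int64, the default integer type on a 64-bit build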
Verdict: tested on Linux with a 16-core Intel CPU, single-threaded performance only. Every code snippet was monitored with htop to make sure all the flags worked as intended (single-threaded).
Julia - single-threaded Int64 (Int64 is the default in Julia) - 940ms
using LinearAlgebra, BenchmarkTools
LinearAlgebra.BLAS.set_num_threads(1) # integer matmul never reaches BLAS, but pin its threads anyway
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A, B)
    # multiply each batch slice, writing into C in place via mul!
    C = Array{Int,3}(undef, size(A, 1), size(A, 2), size(B, 3))
    for i in axes(A, 1)
        mul!(@view(C[i, :, :]), @view(A[i, :, :]), @view(B[i, :, :]))
    end
    C
end
C = mm(a, b);
@btime mm($a, $b) # 940ms
Julia - single-threaded Int32 - 445ms
using LinearAlgebra, BenchmarkTools
a = reshape(Int32.(1:2*2000*400), 2,2000,400);
b = reshape(Int32.(1:2*2000*400), 2,400,2000);
LinearAlgebra.BLAS.set_num_threads(1)
function mm(A, B)
    C = Array{Int32,3}(undef, size(A, 1), size(A, 2), size(B, 3))
    for i in axes(A, 1)
        # each iteration allocates a temporary for the product, then copies it into C
        @views C[i, :, :] = A[i, :, :] * B[i, :, :]
    end
    C
end
@btime mm($a,$b); # 445ms
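The roughly 2× gap between the Int64 (940ms) and Int32 (445ms) runs is about what SIMD width and memory traffic predict: Int32 elements are half as wide, so each vector register holds twice as many lanes and each slice moves half the bytes. A standalone sketch of the same effect (hypothetical 400×400 arrays; on AVX2 hardware the Int32 product often lands close to half the Int64 time):
using BenchmarkTools
A64 = rand(Int64(1):Int64(10), 400, 400);
A32 = rand(Int32(1):Int32(10), 400, 400);
@btime $A64 * $A64;  # 8 bytes per element
@btime $A32 * $A32;  # 4 bytes per element, twice the SIMD lanes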
Python JAX - single-threaded Int64 - 2.2s
import os
# set XLA's threading flags before JAX initializes its CPU backend
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=0")
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)  # enable 64-bit types (JAX defaults to 32-bit)
from jax import local_device_count
print(local_device_count())
a = jnp.arange(2 * 2000 * 400, dtype = jnp.int64).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000, dtype = jnp.int64).reshape((2, 400, 2000)) + 1
%%timeit
c = jnp.matmul(a,b) # c is int64 - 2.2s (JAX dispatches asynchronously; c.block_until_ready() makes the timing airtight)
Python JAX - single-threaded Int32 (Int32 is the default in JAX) - 676ms
import os
# set XLA's threading flags before JAX initializes its CPU backend
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=0")
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", False)  # Int32 is already the JAX default
a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
%%timeit
c = jnp.matmul(a,b) # c is int32 - 676ms
Old:
I am doing performance tests. These are some benchmarks in Julia, with JAX (Python) as the baseline. I am using the latest Julia 1.7 beta and the latest version of JAX on a recent Intel MacBook Pro. Threaded performance is really good; how do I speed up single-threaded performance?
# JAX
import jax.numpy as jnp
a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
%%timeit
jnp.matmul(a,b) # 370ms
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A, B)
    # naive version: one allocating matmul per batch slice
    C = zeros(Int, size(A, 1), size(A, 2), size(B, 3))
    for i in axes(A, 1)
        C[i, :, :] = @view(A[i, :, :]) * @view(B[i, :, :])
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 838ms
using LinearAlgebra
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A, B)
    # same loop, but mul! writes each product into C in place
    C = zeros(Int, size(A, 1), size(A, 2), size(B, 3))
    for i in axes(A, 1)
        mul!(@view(C[i, :, :]), @view(A[i, :, :]), @view(B[i, :, :]))
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 823ms
using MKL # swaps MKL in for BLAS calls; integer matmul never reaches BLAS, which is why the timing barely moves
using LinearAlgebra
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A, B)
    C = zeros(Int, size(A, 1), size(A, 2), size(B, 3))
    for i in axes(A, 1)
        C[i, :, :] = @view(A[i, :, :]) * @view(B[i, :, :])
    end
    C
end
@btime mm(a,b) #881ms
# Threads.nthreads() = 4
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A, B)
    C = zeros(Int, size(A, 1), size(A, 2), size(B, 3))
    # only 2 batch slices, so at most 2 of the 4 threads do useful work
    Threads.@threads for i in axes(A, 1)
        C[i, :, :] = @view(A[i, :, :]) * @view(B[i, :, :])
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 472ms