Speed comparison: matrix multiplication in Julia

Optimized version in Julia.

(126ms opt. Julia vs. 676ms in Jax - Int32)
(521ms opt. Julia vs. 2.2s in Jax - Int64)

Using Tullio, Int32

using Tullio, BenchmarkTools, LoopVectorization
# Row major
a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k]; # 682ms


using Tullio, BenchmarkTools, LoopVectorization
# Column major
a = Array(reshape(Int32.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int32.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i]; # 126ms

Using Tullio, Int64

using Tullio, BenchmarkTools, LoopVectorization
# Row major
a = Array(reshape(Int.(1:2*2000*400), 2,2000,400));
b = Array(reshape(Int.(1:2*2000*400), 2,400,2000));
@btime @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k]; # 1.34s


using Tullio, BenchmarkTools, LoopVectorization
# Column major
a = Array(reshape(Int.(1:2*2000*400), 2000,400,2));
b = Array(reshape(Int.(1:2*2000*400), 400,2000,2));
@btime @tullio c[j, k, i] := $a[j, q, i] * $b[q, k, i]; # 521ms

Update:

After discussions, I found out that Jax uses multi-threading by default and does its calculations in Int32. In Julia, the defaults are single-threaded and Int64. So the original comparison was flawed.
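
A quick way to check these defaults on the Julia side (a minimal sketch; the exact thread counts depend on how Julia and BLAS were started):

using LinearAlgebra
Threads.nthreads()        # Julia threads: 1 unless started with `julia -t N`
BLAS.get_num_threads()    # BLAS threads: usually > 1 by default
eltype(1:10)              # Int64 on a 64-bit system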

Verdict: tested on Linux with an Intel CPU (16 cores), single-threaded performance only. Every code snippet was monitored with htop to make sure all the flags worked as intended (single-threaded).

Julia - single-threaded Int64 (Int64 is the default in Julia) - 940ms

using LinearAlgebra, BenchmarkTools
LinearAlgebra.BLAS.set_num_threads(1)
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A,B)
    C = Array{Int, 3}(undef, size(A,1),size(A,2), size(B,3))
    for i in axes(A, 1)
        mul!(@view(C[i,:,:]), @view(A[i,:,:]), @view(B[i,:,:]))
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 940ms

Julia - single-threaded Int32 - 445ms

using LinearAlgebra, BenchmarkTools
a = reshape(Int32.(1:2*2000*400), 2,2000,400);
b = reshape(Int32.(1:2*2000*400), 2,400,2000);
LinearAlgebra.BLAS.set_num_threads(1)
function mm(A,B)
    C = Array{Int32, 3}(undef, size(A,1), size(A,2), size(B,3))
    for i in axes(A,1)
        @views C[i,:,:] = A[i,:,:] * B[i,:,:]
    end
    C
end
@btime mm($a,$b); # 445ms

Python Jax - single-threaded Int64 - 2.2s

import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)

import os
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")

from jax import local_device_count
print(local_device_count())

a = jnp.arange(2 * 2000 * 400, dtype = jnp.int64).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000, dtype = jnp.int64).reshape((2, 400, 2000)) + 1

%%timeit
c = jnp.matmul(a,b) # c is int64 - 2.2s

Python Jax - single-threaded Int32 (Int32 is the default in Jax) - 676ms

import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", False)

import os
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=0")

a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1

%%timeit
c = jnp.matmul(a,b) # c is int32 - 676ms

Old:

I am doing performance tests. These are some benchmarks in Julia with Jax (Python) as a baseline. I am using the latest Julia 1.7 beta and the latest version of Jax on the latest Intel MacBook Pro. Threaded performance is really good; how do I speed up single-threaded performance?

# JAX
import jax.numpy as jnp
a = jnp.arange(2 * 2000 * 400).reshape((2, 2000, 400)) + 1
b = jnp.arange(2 * 400 * 2000).reshape((2, 400, 2000)) + 1
%%timeit
jnp.matmul(a,b) # 370ms


using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A,B)
    C = zeros(Int,size(A)[1],size(A)[2], size(B)[3])
    for i in axes(A)[1]
        C[i,:,:] = @view(A[i,:,:]) * @view(B[i,:,:])
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 838ms


using LinearAlgebra
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A,B)
    C = zeros(Int,size(A)[1],size(A)[2], size(B)[3])
    for i in axes(A)[1]
        mul!(@view(C[i,:,:]), @view(A[i,:,:]), @view(B[i,:,:]))
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 823ms


using MKL
using LinearAlgebra
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A,B)
    C = zeros(Int,size(A)[1],size(A)[2], size(B)[3])
    for i in axes(A)[1]
        C[i,:,:] = @view(A[i,:,:]) * @view(B[i,:,:])
    end
    C
end
@btime mm(a,b) #881ms


# Threads.nthreads() = 4
using BenchmarkTools
a = reshape(1:2*2000*400, 2,2000,400);
b = reshape(1:2*2000*400, 2,400,2000);
function mm(A,B)
    C = zeros(Int,size(A)[1],size(A)[2], size(B)[3])
    Threads.@threads for i in axes(A)[1]
        C[i,:,:] = @view(A[i,:,:]) * @view(B[i,:,:])
    end
    C
end
C = mm(a,b);
@btime mm(a,b) # 472ms

Several points here:

  • This a is a lazy range object, while I think np.arange is a dense array. The lazy one is cheap and good for many purposes, but probably not for matrix multiplication.
  • Julia’s arrays are column-major, which means that slicing the first dimension would, for dense arrays, mean that the views aren’t contiguous memory.
  • These are arrays of integers, which I don’t think BLAS libraries handle.

These effects may cost you an order of magnitude or more:

julia> x = rand(100,100); y = rand(100,100);
julia> @btime $x * $y;
  17.541 μs (2 allocations: 78.17 KiB)

julia> x = view(rand(2,100,100),1,:,:); y = view(rand(2,100,100),1,:,:);  # awkward views
julia> @btime $x * $y;
  416.583 μs (8 allocations: 78.41 KiB)

julia> x = rand(1:99, 100,100); y = rand(1:99, 100,100);  # integers
julia> @btime $x * $y;
  384.959 μs (8 allocations: 78.41 KiB)
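
The lazy-range point can be checked the same way; a sketch only (no timings shown, and exact numbers will depend on your machine):

julia> x = reshape(1:100*100, 100, 100); y = collect(x);  # lazy ReshapedArray vs. dense Matrix{Int}

julia> @btime $x * $x;  # generic matmul through the lazy wrapper

julia> @btime $y * $y;  # generic matmul over contiguous storage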

If you in fact want floating point numbers, then the operation you write is NNlib.batched_mul, although with the batch dimension last. This should multi-thread over that dim. Apart from that, all the work is done by whichever library you are using.
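
For example, a minimal sketch assuming NNlib.jl is installed (note the batch dimension goes last):

using NNlib
A = rand(2000, 400, 2);   # batch of two 2000×400 matrices, batch dimension last
B = rand(400, 2000, 2);
C = batched_mul(A, B);    # size (2000, 2000, 2); one gemm per batch slice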

If you do want integers, then you probably want to look into LoopVectorization-powered solutions: Tullio, or Octavian.
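
For instance, a sketch with Octavian (assuming Octavian.jl is installed; it is pure Julia, so integer element types are fine):

using Octavian
A = rand(Int32(1):Int32(99), 2000, 400);
B = rand(Int32(1):Int32(99), 400, 2000);
C = Octavian.matmul(A, B);   # or matmul!(C, A, B) with a preallocated C

This handles one 2000×400 slice at a time; looping over the batch dimension stays the same as in your mm.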


See Matrix Multiplication · LoopVectorization.jl


I want to say the Jax (Python in general) installation experience is horrible: pip install jax tells me to also install jaxlib, and jaxlib somehow downloads the TensorFlow runtime and its own LLVM and starts compiling with Bazel for 10 minutes, just for me to run your code and verify my hypothesis.


again, I want to point out that:

In [2]: a = np.arange(10)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [3]: a
Out[3]: DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

almost all of Python’s “ML” backends fall back to 32-bit numbers by default:

Just by putting the calculation into Int32 like Jax, even before correcting for column-major vs. row-major:

julia> @btime mm(a,b);
  1.189 s (18 allocations: 122.07 MiB)

julia> a = reshape(Int32.(1:2*2000*400), 2,2000,400);

julia> b = reshape(Int32.(1:2*2000*400), 2,400,2000);

julia> @btime mm(a,b);
  571.098 ms (18 allocations: 91.55 MiB)

Good point about 32 bits. Note also that these numbers are big enough to overflow in Int32; I’m not sure how Jax handles that:

julia> extrema(a[1,:,:] * b[1,:,:])  # Int64
(170346560000, 513193706560000)

julia> extrema((a[1,:,:] .+ 0.0) * (b[1,:,:] .+ 0.0))  # Float64
(1.7034656e11, 5.1319370656e14)

julia> typemax(Int32) + 0.0
2.147483647e9

julia> extrema(Int32.(a[1,:,:]) * Int32.(b[1,:,:]))
(-2147480576, 2147480064)

I reduced the problem size by a factor of 2, just to check LoopVectorization performance:

julia> using LoopVectorization

julia> A = reshape(1:2000*400, 2000,400);

julia> B = reshape(1:2000*400, 400,2000);

julia> C = Matrix{Int}(undef, 2000, 2000);

julia> function A_mul_B!(C, A, B)
           @turbo for n ∈ indices((C,B), 2), m ∈ indices((C,A), 1)
               Cmn = zero(eltype(C))
               for k ∈ indices((A,B), (2,1))
                   Cmn += A[m,k] * B[k,n]
               end
               C[m,n] = Cmn
           end
       end

julia> @btime A_mul_B!(C, A, B)
  488.673 ms (0 allocations: 0 bytes)

The time taken is roughly cut in half (the problem size is halved). LoopVectorization doesn’t offer any speedup.

I guess your matrices are too large? Anyway, I was mostly posting this link to show you that others have already compared matrix multiplication against things like Intel MKL, so I thought it might interest you. (Not to suggest that adding @turbo will necessarily speed up your case.)


Int32 doesn’t work, sadly; I get integer overflow because the numbers are too big.

julia> a = reshape(Int32.(1:2*2000*400), 2,2000,400);

julia> b = reshape(Int32.(1:2*2000*400), 2,400,2000);

julia> function mm(A,B)
           C = zeros(Int,size(A)[1],size(A)[2], size(B)[3])
           for i in axes(A)[1]
               C[i,:,:] = @view(A[i,:,:]) * @view(B[i,:,:])
           end
           C
       end
mm (generic function with 1 method)

julia>

julia> @btime mm(a,b)
  963.990 ms (18 allocations: 91.55 MiB)
2×2000×2000 Array{Int64, 3}:
[:, :, 1] =
 -1452131840  -1451811840  -1451491840  -1451171840  -1450851840  -1450531840  …   -814051840   -813731840   -813411840   -813091840   -812771840   -812451840
 -1132771040  -1132450240  -1132129440  -1131808640  -1131487840  -1131167040  …   -493095840   -492775040   -492454240   -492133440   -491812640   -491491840


In [2]: c
Out[2]: 
DeviceArray([[[ -283192760,  -283112560,  -283032360, ...,  -123033360,
                -122953160,  -122872960],
              [ -867542200,  -867302000,  -867061800, ...,  -387862800,
                -387622600,  -387382400],
              [-1451891640, -1451491440, -1451091240, ...,  -652692240,
               ...
             ]]], dtype=int32)


Sure. I was surprised by Jax’s performance (it’s approximately 15 times faster than NumPy), and then I started investigating in Julia.

Thanks for your help

Tullio et al. don’t help when the matrix is too large and needs BLAS multi-threading anyway. @sivakon, single-threaded is an illusion here: although Python/Julia run in a single thread, the BLAS backend can still be multi-threaded.


Tiny thing, but you want to use axes(A, 1) here.

Typically, if you see such a speed-up for a “well-known problem”, it is much more likely that you’re unknowingly comparing apples and oranges (different precision/data types, multithreading vs. no multithreading, etc.) than that one thing is just magically faster. (There can be exceptions, of course.)


Guess I missed that in Jax. Oops.

I am just exploring multiple ways to optimize it, multi-threading or not. Threads.@threads is a super easy optimization, which is why I included it.

Your arrays are large enough to need more than just the kernel provided, which is why I suggested libraries built on top of it. Both Tullio and Octavian handle threading and memory access, and both get me about a factor of 4 improvement in Int64. With Int32 to match Jax, ignoring the inaccuracy, they are 10x faster for me. But this may vary a lot based on your CPU.

Note that BLAS * is itself multi-threaded, which you would normally want to turn off if threading the outermost loop yourself. But Julia’s generic matmul, which is what your example uses, is not.
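
A sketch of what that combination could look like, with Float64 arrays and the batch dimension last (hypothetical sizes matching the example above):

using LinearAlgebra
BLAS.set_num_threads(1)                   # avoid oversubscription when threading the outer loop ourselves
A = rand(2000, 400, 2); B = rand(400, 2000, 2);
C = Array{Float64,3}(undef, 2000, 2000, 2);
Threads.@threads for i in axes(A, 3)      # thread over the batch dimension
    @views mul!(C[:, :, i], A[:, :, i], B[:, :, i])   # contiguous slices, each a single-threaded BLAS gemm
end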


Tullio seems to work:

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime(@tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k])
  173.740 ms (244 allocations: 30.53 MiB)

julia> @btime mm($a,$b);
  837.323 ms (18 allocations: 61.04 MiB)


julia> @tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k];

julia> c == mm(a,b)
true

LoopVectorization is still for small arrays only. Other libraries built on top of it like Octavian and Tullio add large array support.


You are correct. I am going to update the post. Julia is a lot faster if we compare Int64 vs. Int64. My post was comparing Int32 (Jax) vs. Int64 (Julia).


Can you replicate the Tullio result in the single-threaded case? I get no speedup single-threaded:

julia> using Tullio

julia> using BenchmarkTools

julia> Threads.nthreads() # 1
1

julia> a = Array(reshape(Int32.(1:2*2000*400), 2,2000,400));

julia> b = Array(reshape(Int32.(1:2*2000*400), 2,400,2000));

julia> @btime(@tullio c[i, j, k] := $a[i, j, q] * $b[i, q, k])
  2.340 s (2 allocations: 30.52 MiB)