Challenge: Can you beat Python and C++ in Int4 Matrix-Vector Multiply Op?

TL;DR How close can we get to 4x speedup if we move from Float16 to Int4 in mat-vec multiply?

Hi everyone,

It’s challenge time! I’m calling on the Julia community to show Python’s bitsandbytes (recent release with a 2-4x speed-up thanks to Int4) and Llama.cpp (see its Quantization Stats) how fast we can perform an Int4 matrix-vector multiplication.

Do you have the mettle to leverage Metal.jl or other accelerators to push the boundaries of speed? There are special imaginary points up for grabs for the quickest implementation!

Time to show what Julia can do. Let the games begin!
Yes, this is a joke – I’m sure we CANNOT simply beat the hand-crafted kernels they use. I’m just curious what can be done with bitpacking and tricks available in Julia…

Test case:

  • A matrix size that corresponds to an MLP layer in the Llama model (4096x11008)
  • Float16 values that can be directly saved in Int4 without any loss (to not worry about quantization tricks)

Setup code:

using Pkg;
Pkg.activate(".");
Pkg.add(["BenchmarkTools", "LinearAlgebra"]);
using BenchmarkTools, LinearAlgebra

# Generate data (but accumulate results in Float32!)
mat_size = (4096, 11008)
m = rand(0:15, mat_size) .|> Float16
v = rand(Float16, mat_size[2]) .* Float16(1e-3) # scaled to not overflow

# Baseline implementation (keep in F16)
function mygemv1!(output::AbstractVector{<:Real}, mat::AbstractMatrix{<:Real}, vect::AbstractVector{<:Real})
    @inbounds @simd for j in axes(mat, 2)
        for i in axes(mat, 1)
            output[i] += mat[i, j] * vect[j]
        end
    end
    return output
end

# Timings
exp_output = zeros(Float32, mat_size[1])
@time mul!(exp_output, m, v)
# 0.004544 seconds
t0 = @belapsed mul!(output, $m, $v) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# 0.003590833

output = zeros(Float32, mat_size[1])
@time mygemv1!(output, m, v);
# 0.005604 seconds
@assert exp_output == output
t1 = @belapsed mygemv1!(output, $m, $v) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# 0.003589542

# Speed-up
t0/t1 # ~1x, i.e., unsurprisingly, none! :)

(Helpful) Resources:

  • Python bitsandbytes library releases support for 4-bit inference with 2-4x speedup over FP16 link
  • Llama.cpp implementation for Int4 RtN for Metal link
  • PR for the new Int4 kernel for Metal in Llama.cpp (with prebuffering) link
  • Julia issue for the arbitrary bit-width support link
  • StackOverflow on bitpacking performance in Julia vs Python link
  • GPTQ Paper showing >3-4.5x speed up with 3-bit inference link
  • Julia Issue about support for sub-byte integers link

PS: Most of the above implementations accumulate in Float32, so use that if helpful. Operations that can be done upfront (on load) shouldn’t be counted (e.g., transpose, etc.).

EDIT: updated the reference to accumulate in Float32 and with correct range.


I tried to add a baseline with Metal.jl in F16, but I’m doing something wrong (correctness check fails)…

Edit: baseline in Metal.jl updated and it’s correct now, but it still doesn’t leverage Int4 storage

My attempt:

# Base case with acceleration
using Metal
mm = MtlArray(m)
mv = MtlArray(v)

output = Metal.zeros(Float32, mat_size[1])
@time mul!(output, mm, mv) # we can just call a method of mul!
# 0.000157 seconds (23 allocations: 448 bytes)

@assert Array(output) ≈ exp_output

mt0 = @belapsed mul!(output, $mm, $mv) setup = (output = Metal.zeros(Float32, mat_size[1])) evals = 1
# 1.1125e-5

t0/mt0 
# 323x speed up

using LoopVectorization

function mygemmavx!(C, A, B)
    @turbo for m ∈ axes(A,1), n ∈ axes(B,2)
        Cmn = zero(eltype(C))
        for k ∈ axes(A,2)
            Cmn += A[m,k] * B[k,n]
        end
        C[m,n] = Cmn
    end
end

Try something like this?


Yea your definition of v is messed up somehow. Int.(m) * Float64.(v) gives correct results. Not sure what is really going on though.

EDIT: Or actually, something about m * v is messed up, since the fancier methods seem to work fine?

Thank you for the reply!

LoopVectorization.jl is cool, but as far as I understand, it cannot use the fact that we could represent the values as Int4 and hence fit 4 times as many of them in a cache line. Or am I wrong?
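
For reference, the cache-line arithmetic behind that hope (just a sanity check, not a benchmark, and assuming 64-byte cache lines):

64 ÷ sizeof(Float16)   # 32 Float16 values per 64-byte cache line
64 * 2                 # 128 Int4 values per line (two per byte), i.e. 4x as many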

The reference implementation ends up slower than the base case on the problem above:

using LoopVectorization
function mygemmavx!(output::AbstractVector{Float32}, mat::AbstractMatrix{<:Real}, vect::AbstractVector{<:Real})
    @turbo for j in axes(mat, 2)
        for i in axes(mat, 1)
            output[i] += mat[i, j] * vect[j]
        end
    end
end
output = zeros(Float32, mat_size[1])
@time mygemmavx!(output, m, v);
# 0.002024 seconds
@assert exp_output ≈ output
t2 = @belapsed mygemmavx!(output, $m, $v) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# 0.010778042

# Speed-up
t0 / t2 # 0.3x, i.e., a slowdown

Yea your definition of v is messed up somehow. Int.(m) * Float64.(v) gives correct results. Not sure what is really going on though.

As for Metal.jl, it was correct all along.
My reference wasn’t, because the Float16 accumulation hit some precision/truncation issues.
Metal.jl just handled the Float16 output much more gracefully (probably because it works in F32 and then converts to F16).

I’ve updated the code above to accumulate results in F32, so it all passes now.

EDIT: I’m working on writing a Metal.jl kernel that can leverage the Int4, but it’s painful :smiley: So far, I haven’t been able to get anywhere near the mul! performance with for loops (without adding the bitpacking). My kernel is 500x slower…

Writing a fast matmul kernel is hard… See for example JuliaGPU/GemmKernels.jl (flexible and performant GEMM kernels in Julia), which currently reaches 50-80% of CUBLAS. Although I haven’t benchmarked MPS (the library that powers Metal.jl’s mul!), I wouldn’t expect it to be easy to write a kernel that performs similarly.

You could convert from Int4, but are there any special instructions involved?

Results are going to vary by machine.
For me, mul! took 25ms, mygemv1! 7ms, and mygemmavx! 9ms.

I’m using LinuxPerf to get a summary of running them 1000 times:
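
(foreachf! isn’t defined in this post; presumably it is just a small repeat-runner along these lines, shown here as an assumption:)

function foreachf!(f::F, N, args...) where {F}
    # call f(args...) N times so the perf counters aggregate over many runs
    for _ in 1:N
        f(args...)
    end
end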

julia> @pstats "cpu-cycles,(instructions,branch-instructions,branch-misses),(task-clock,context-switches,cpu-migrations,page-faults),(L1-dcache-load-misses,L1-dcache-loads,L1-icache-load-misses),(dTLB-load-misses,dTLB-loads),(iTLB-load-misses,iTLB-loads)" begin
       foreachf!(mygemv1!, 1000, output, m, v)
       end
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
╶ cpu-cycles               2.65e+10   60.0%  #  3.7 cycles per ns
┌ instructions             3.14e+10   60.0%  #  1.2 insns per cycle
│ branch-instructions      2.88e+09   60.0%  #  9.2% of insns
└ branch-misses            1.10e+07   60.0%  #  0.4% of branch insns
┌ task-clock               7.20e+09  100.0%  #  7.2 s
│ context-switches         0.00e+00  100.0%
│ cpu-migrations           0.00e+00  100.0%
└ page-faults              0.00e+00  100.0%
┌ L1-dcache-load-misses    1.42e+09   20.0%  # 24.9% of dcache loads
│ L1-dcache-loads          5.68e+09   20.0%
└ L1-icache-load-misses    1.95e+05   20.0%
┌ dTLB-load-misses         7.00e+02   20.0%  #  0.0% of dTLB loads
└ dTLB-loads               5.68e+09   20.0%
┌ iTLB-load-misses         1.53e+02   40.0%  # 70.9% of iTLB loads
└ iTLB-loads               2.15e+02   40.0%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

julia> @pstats "cpu-cycles,(instructions,branch-instructions,branch-misses),(task-clock,context-switches,cpu-migrations,page-faults),(L1-dcache-load-misses,L1-dcache-loads,L1-icache-load-misses),(dTLB-load-misses,dTLB-loads),(iTLB-load-misses,iTLB-loads)" begin
       foreachf!(mygemmavx!, 1000, output, m, v)
       end
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
╶ cpu-cycles               3.40e+10   60.0%  #  3.7 cycles per ns
┌ instructions             8.46e+09   60.0%  #  0.2 insns per cycle
│ branch-instructions      3.52e+08   60.0%  #  4.2% of insns
└ branch-misses            4.47e+04   60.0%  #  0.0% of branch insns
┌ task-clock               9.24e+09  100.0%  #  9.2 s
│ context-switches         0.00e+00  100.0%
│ cpu-migrations           0.00e+00  100.0%
└ page-faults              0.00e+00  100.0%
┌ L1-dcache-load-misses    1.41e+09   20.0%  # 44.6% of dcache loads
│ L1-dcache-loads          3.17e+09   20.0%
└ L1-icache-load-misses    2.71e+05   20.0%
┌ dTLB-load-misses         1.01e+04   20.0%  #  0.0% of dTLB loads
└ dTLB-loads               3.17e+09   20.0%
┌ iTLB-load-misses         2.62e+02   40.0%  # 66.5% of iTLB loads
└ iTLB-loads               3.95e+02   40.0%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

julia> 3.14e10/8.46e9
3.7115839243498816

Note that the @turbo version required 3.7x fewer instructions, but ran at 1/6 the instructions per clock cycle, so it still took more time.
It suffered 44.6% L1d-cache load misses, vs 24.9% for mygemv1!.

@turbo blocks m’s rows into chunks, sweeping across all columns for each of these chunks.
You could also try blocking m’s columns (and the corresponding chunks of v).
Certain blocking patterns might improve performance, e.g. by using more of each page of m (less TLB pressure?) and by making better use of the hardware prefetchers, which are obviously failing here but seem to work a little better with mygemv1!'s access pattern.
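
A rough, untuned sketch of that kind of two-level blocking (the function name and block sizes here are just illustrative guesses, not a tested recommendation):

# Illustrative only: block the columns of `mat` (with the matching chunk of `vect`)
# and the rows of `mat` (with the matching chunk of `output`), so the active pieces
# of both vectors stay hot in cache while sweeping through `mat`.
function mygemv_blocked!(output::AbstractVector{<:Real}, mat::AbstractMatrix{<:Real},
                         vect::AbstractVector{<:Real}; rowblock=512, colblock=2048)
    @inbounds for jb in 1:colblock:size(mat, 2)
        jend = min(jb + colblock - 1, size(mat, 2))
        for ib in 1:rowblock:size(mat, 1)
            iend = min(ib + rowblock - 1, size(mat, 1))
            for j in jb:jend
                vj = vect[j]
                @simd for i in ib:iend
                    output[i] += mat[i, j] * vj
                end
            end
        end
    end
    return output
end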

You could convert from Int4, but are there any special instructions involved?

No, the hope is that you drastically reduce the memory traffic and thereby get more effective bandwidth. I unpack the nibbles manually (see below).
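
Roughly quantified (a back-of-the-envelope estimate from the ~3.6 ms baseline timing above, so very machine-dependent):

bytes_f16 = 4096 * 11008 * 2    # ≈ 90 MB of Float16 weights read per mat-vec
bytes_f16 / 0.00359 / 1e9       # ≈ 25 GB/s effective bandwidth at the baseline timing
bytes_q4 = 4096 * 11008 ÷ 2     # ≈ 22.5 MB once packed two values per byte
bytes_q4 / 25e9                 # ≈ 0.9 ms if the same bandwidth were sustained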


I’ve tried the Int4 quantized implementation with @turbo, but something is going wrong… I get wrong results and, from time to time, a segfault.
Any idea what I’m doing wrong? (See mul_q4_avx! below.)

Setup

abstract type AbstractQuantization end
struct Int4Naive <: AbstractQuantization end

"Pack two Int4 numbers into one UInt8 (assumes perfect input matching the type and block dimensions)"
function pack(::Int4Naive, a::Int, b::Int)
    @assert 0 <= a <= 15 "a must be between 0 and 15 (provided: $a)"
    @assert 0 <= b <= 15 "b must be between 0 and 15 (provided: $b)"
    return UInt8((a << 4) + b)
end
pack(q::Int4Naive, a::Float16, b::Float16) = pack(q, Int(a), Int(b))
pack(Int4Naive(), Float16(2), Float16(3))

# Unpack a UInt8 into two Int4 numbers
function unpack(::Int4Naive, x::UInt8)
    return x >> 4, x & 0xF
end

# quick check
let q = Int4Naive()
    @assert pack(q, 2, 3) == 0x23
    @assert unpack(q, 0x23) == (2, 3)
    @assert pack(q, Float16(2), Float16(3)) == 0x23
end

# find positions in the original data
origin_col_idx(i, j, NBLOCKS, BLOCKSIZE=32) = (mod1(j, NBLOCKS) - 1) * BLOCKSIZE + i
origin_row_idx(j, NBLOCKS) = fld1(j, NBLOCKS)
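
# (Illustrative sanity check of the mapping, assuming NBLOCKS = 344, i.e. 11008 original columns:
#  column j of qm covers block mod1(j, NBLOCKS) of row fld1(j, NBLOCKS) of m, and row i of qm
#  holds the original columns col_idx and col_idx + 16 of that block.)
@assert origin_row_idx(344, 344) == 1 && origin_row_idx(345, 344) == 2
@assert origin_col_idx(1, 345, 344) == 1 && origin_col_idx(16, 2, 344) == 48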

# dispatch on a matrix
function pack(q::Int4Naive, m::Matrix{Float16}, BLOCKSIZE=32)
    HALFBLOCK = BLOCKSIZE ÷ 2 # 16
    mat_size = size(m)
    NCOLS = mat_size[2]
    NBLOCKS = NCOLS ÷ BLOCKSIZE
    qm = Matrix{UInt8}(undef, BLOCKSIZE ÷ 2, NBLOCKS * mat_size[1])
    for j in axes(qm, 2)
        row_idx = origin_row_idx(j, NBLOCKS)
        for i in axes(qm, 1)
            col_idx = origin_col_idx(i, j, NBLOCKS)
            qm[i, j] = pack(q, m[row_idx, col_idx], m[row_idx, col_idx+HALFBLOCK])
        end
    end
    qm
end

# check implementation
BLOCKSIZE = 32
Q = Int4Naive()
m = rand(0:15, 1024, 4096) .|> Float16
qm = pack(Q, m, BLOCKSIZE)
NBLOCKS = size(m, 2) ÷ BLOCKSIZE
for i in axes(qm, 1), j in axes(qm, 2)
    x = qm[i, j]
    x1 = m[origin_row_idx(j, NBLOCKS), origin_col_idx(i, j, NBLOCKS)]
    x2 = m[origin_row_idx(j, NBLOCKS), origin_col_idx(i, j, NBLOCKS)+16]
    @assert x == pack(Q, x1, x2)
end

Standard loop

BLOCKSIZE = 32
Q = Int4Naive()
mat_size = (4096, 11008)
m = rand(0:15, mat_size) .|> Float16
v = rand(Float32, mat_size[2]) .* Float32(1e-3)
output = zeros(Float32, mat_size[1])
exp_output = m * v
qm = pack(Q, m, BLOCKSIZE)

function mul_q4!(C::Vector{Float32}, A::Matrix{UInt8}, B::Vector{Float32}, BLOCKSIZE=32)
    NBLOCKS = length(B) ÷ BLOCKSIZE
    for j ∈ axes(A, 2)
        Cx = zero(eltype(C))
        # each UInt8 packs two nibbles: the original columns col_idx (high nibble) and col_idx + 16 (low nibble) of the block
        for i ∈ axes(A, 1)
            Cx += (A[i, j] >> 4) * B[origin_col_idx(i, j, NBLOCKS)] + (A[i, j] & 0xF) * B[origin_col_idx(i, j, NBLOCKS)+16]
        end
        C[origin_row_idx(j, NBLOCKS)] += Cx
    end
end

output = zeros(Float32, mat_size[1])
mul_q4!(output, qm, v, BLOCKSIZE)
@assert exp_output ≈ output
@btime mul_q4!(output, $qm, $v, $BLOCKSIZE) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# 58.891 ms (0 allocations: 0 bytes)

LoopVectorization

BLOCKSIZE = 32
Q = Int4Naive()
mat_size = (4096, 11008)
m = rand(0:15, mat_size) .|> Float16
v = rand(Float32, mat_size[2]) .* Float32(1e-3)
output = zeros(Float32, mat_size[1])
exp_output = m * v
qm = pack(Q, m, BLOCKSIZE)

# LoopVectorization.@turbo added
function mul_q4_avx!(C::Vector{Float32}, A::Matrix{UInt8}, B::Vector{Float32}, BLOCKSIZE=32)
    NBLOCKS = length(B) ÷ BLOCKSIZE
    @turbo for j ∈ axes(A, 2)
        Cx = zero(eltype(C))
        for i ∈ axes(A, 1)
            col_idx = origin_col_idx(i, j, NBLOCKS)
            Cx += (A[i, j] >> 4) * B[col_idx] + (A[i, j] & 0xF) * B[col_idx+16]
        end
        C[origin_row_idx(j, NBLOCKS)] += Cx
    end
end

output = zeros(Float32, mat_size[1])
mul_q4_avx!(output, qm, v, BLOCKSIZE)
@assert exp_output ≈ output

# @btime mul_q4_avx!(output, $qm, $v, $BLOCKSIZE) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# crashes

I’ve also hacked up a Metal.jl implementation. It’s really nice how I can call any utility function written in Julia (e.g., my origin_col_idx).

I’m getting around 25ms on my hand-rolled mul and only 6.5ms on the q4 mul below, so a 3.8x speedup.
The problem is that my baseline is still too inefficient for the Int4 version to be genuinely useful.

I suspect I need to figure out a better data access pattern.
I operate on quantized nibbles that are stored contiguously in a column, and the multiplication vector is read sequentially, so I assumed that would be enough, but clearly it isn't!

Metal.jl kernel (inspired by link)

function mul_q4_metal1!(output::MtlDeviceVector{T}, qm::MtlDeviceArray{UInt8}, x::MtlDeviceVector{T}) where {T}
    tgpig = threadgroup_position_in_grid_1d()
    tiisg = thread_index_in_simdgroup()
    NROWS = size(qm, 1)
    SIMDWIDTH = 32     # threads_per_simdgroup()
    NBLOCKS = 344      # length(x) ÷ BLOCKSIZE
    NCOLS_THREAD = 10  # fld(NBLOCKS, SIMDWIDTH)
    # column offset into qm for this thread (one threadgroup per output row, NBLOCKS columns per row)
    thread_col_offset = (tgpig - 1) * NBLOCKS + (tiisg - 1) * NCOLS_THREAD
    # columns thread_col_offset+1 onwards are processed, so +1 tells us which output row we are on
    row_idx = origin_row_idx(thread_col_offset + 1, NBLOCKS)

    # temporary sum (could be replaced by simd_sum function)
    sumf = MtlThreadGroupArray(T, SIMDWIDTH)
    sumf[tiisg] = 0.0f0

    # each thread in the SIMD group handles NCOLS_THREAD blocks, one block (column of qm) per iteration
    for col in 1:NCOLS_THREAD
        acc = 0.0f0
        col_offset = thread_col_offset + col
        for i in 1:NROWS
            col_idx = origin_col_idx(i, col_offset, NBLOCKS)
            acc += (qm[i, col_offset] >> 4) * x[col_idx] + (qm[i, col_offset] & 0xF) * x[col_idx+16]
        end
        sumf[tiisg] += acc
    end
    # any blocks left over?
    if NBLOCKS % SIMDWIDTH > 0
        ## finish the odd blocks
        acc = 0.0f0
        col_offset = (tgpig - 1) * NBLOCKS + SIMDWIDTH * NCOLS_THREAD + tiisg
        for i in 1:NROWS
            col_idx = origin_col_idx(i, col_offset, NBLOCKS)
            acc += (qm[i, col_offset] >> 4) * x[col_idx] + (qm[i, col_offset] & 0xF) * x[col_idx+16]
        end
        if tiisg <= NBLOCKS % SIMDWIDTH
            sumf[tiisg] += acc
        end
    end
    ## save results
    # ideally, we would use simd_sum here, but it's not implemented yet.
    if tiisg == 1
        all_sum = 0.0f0
        for i in 1:SIMDWIDTH
            all_sum += sumf[i]
        end
        output[row_idx] += all_sum
    end
    return nothing
end

v = rand(Float32, mat_size[2]) .* Float32(1e-3)
m = rand(0:15, mat_size) .|> Float16
qm = pack(Q, m, BLOCKSIZE)

# metal arrays
qmmetal = MtlArray(qm)
x = MtlArray(v)
output = Metal.zeros(Float32, mat_size[1])
threads, groups = 32, length(output) # operates by rows (1 row = 1 group)
kernel = @metal threads = threads groups = groups mul_q4_metal1!(output, qmmetal, x)
@assert m * v ≈ Array(output)

# kernel function working as expected
output = Metal.zeros(Float32, mat_size[1])
@time Metal.@sync kernel(output, qmmetal, x; threads, groups)
@assert m * v ≈ Array(output)

# time it
bench = @btime begin
    Metal.@sync kernel(output, $qmmetal, $x; threads, groups)
end setup = (output = Metal.zeros(Float32, mat_size[1])) evals = 1
# 6.516 ms (237 allocations: 5.13 KiB)

# MPS is waaaay faster
mmetal = MtlArray(m)
@btime $mmetal * $x
# 89.000 μs (324 allocations: 7.58 KiB)

@maleadt I’ve been struggling a bit to apply GemmKernels.jl. Would you have any example with Metal.jl, and/or an introductory blog post explaining the “Layouts” (and what they do/why)? If you’re at JuliaCon, we could chat in person :slight_smile:

EDIT: Is it possible to create thread-local arrays in Metal.jl? I haven’t been able to create anything but an MtlThreadGroupArray.

For an explanation of GemmKernels.jl, see the associated paper or thesis.

You can create thread-local arrays in Metal using the MtlThreadGroupArray function (just like in CUDA.jl).
