TL;DR: How close can we get to a 4x speedup if we move from Float16 to Int4 in a matrix-vector multiply?
Hi everyone,
It’s challenge time! I’m calling on the Julia community to show Python’s bitsandbytes (recent release with a 2-4x speed-up thanks to Int4) and Llama.cpp (see its Quantization Stats) how fast we can perform an Int4 matrix-vector multiplication.
Do you have the mettle to leverage Metal.jl or other accelerators to push the boundaries of speed? There are special imaginary points up for grabs for the quickest implementation!
Time to show what Julia can do. Let the games begin!
Yes, this is a joke – I’m sure we CANNOT simply beat the hand-crafted kernels they use. I’m just curious what can be done with bitpacking and tricks available in Julia…
Test case:
- A matrix size that corresponds to an MLP layer in the Llama model (4096x11008)
- Float16 values that can be stored as Int4 without any loss (so we don’t have to worry about quantization tricks)
Setup code:
using Pkg;
Pkg.activate(".");
Pkg.add(["BenchmarkTools", "LinearAlgebra"]);
using BenchmarkTools, LinearAlgebra
# Generate data (but accumulate results in Float32!)
mat_size = (4096, 11008)
m = rand(0:15, mat_size) .|> Float16
v = rand(Float16, mat_size[2]) .* Float16(1e-3) # scaled down to avoid overflow
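# (Not part of the benchmark) A minimal bitpacking sketch, done upfront "on load":
# since every value of `m` lies in 0:15, two values fit into one UInt8 (low and high
# nibble). `pack_int4` is just my name for this hypothetical helper, not an existing API.
function pack_int4(mat::AbstractMatrix{Float16})
    @assert iseven(size(mat, 1)) "pairing nibbles along columns needs an even number of rows"
    packed = Matrix{UInt8}(undef, size(mat, 1) ÷ 2, size(mat, 2))
    @inbounds for j in axes(mat, 2), i in axes(packed, 1)
        lo = UInt8(mat[2i - 1, j]) # value in 0:15 -> low nibble
        hi = UInt8(mat[2i, j])     # value in 0:15 -> high nibble
        packed[i, j] = lo | (hi << 4)
    end
    return packed
end
mpacked = pack_int4(m) # 2048x11008 UInt8, a quarter of the Float16 matrix's memory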
# Baseline implementation (keep values in Float16, accumulate into the Float32 output)
function mygemv1!(output::AbstractVector{<:Real}, mat::AbstractMatrix{<:Real}, vect::AbstractVector{<:Real})
    @inbounds for j in axes(mat, 2)
        @simd for i in axes(mat, 1) # column-major friendly: inner loop runs down a column
            output[i] += mat[i, j] * vect[j]
        end
    end
    return output
end
# Timings
exp_output = zeros(Float32, mat_size[1])
@time mul!(exp_output, m, v)
# 0.004544 seconds
t0 = @belapsed mul!(output, $m, $v) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# 0.003590833
output = zeros(Float32, mat_size[1])
@time mygemv1!(output, m, v);
# 0.005604 seconds
@assert exp_output == output
t1 = @belapsed mygemv1!(output, $m, $v) setup = (output = zeros(Float32, mat_size[1])) evals = 1
# 0.003589542
# Speed-up
t0 / t1 # ~1x, i.e., unsurprisingly, no speed-up! :)
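For reference, here is the kind of naive scalar starting point I’d imagine for the packed matrix (a sketch only, building on the hypothetical pack_int4 helper above; it unpacks nibbles in the hot loop, so I wouldn’t expect it to beat mul! without further tricks):

# Naive "Int4" mat-vec over the packed matrix from pack_int4 above
function mygemv_packed!(output::AbstractVector{Float32}, packed::AbstractMatrix{UInt8}, vect::AbstractVector{Float16})
    @inbounds for j in axes(packed, 2)
        vj = Float32(vect[j])
        for i in axes(packed, 1)
            b = packed[i, j]
            output[2i - 1] += Float32(b & 0x0f) * vj # low nibble
            output[2i]     += Float32(b >> 4) * vj   # high nibble
        end
    end
    return output
end
output_packed = zeros(Float32, mat_size[1])
mygemv_packed!(output_packed, mpacked, v)
@assert output_packed ≈ exp_output # ≈ because the accumulation order differs from mul!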
(Helpful) Resources:
- Python bitsandbytes library releases support for 4-bit inference with 2-4x speedup over FP16 link
- Llama.cpp implementation for Int4 RtN for Metal link
- PR for the new Int4 kernel for Metal in Llama.cpp (with prebuffering) link
- Julia issue for the arbitrary bit-width support link
- StackOverflow on bitpacking performance in Julia vs Python link
- GPTQ Paper showing >3-4.5x speed up with 3-bit inference link
- Julia Issue about support for sub-byte integers link
PS: Most of the above implementations accumulate in Float32, so use that if helpful. Operations that can be done upfront (on load) shouldn’t be counted (e.g., transposing, etc.).
EDIT: updated the reference code to accumulate in Float32 and to use the correct value range.