I have implemented a basic transformer-like model in Flux and have a reference version in PyTorch 2.0 (which does a few extra ops, such as masking, but basically no less work).
I am getting comparable performance in the forward pass:
- PyTorch 2.0 (compiled) → 115 ms
- Flux model → 103 ms (the Flux model does no masking, so it has an advantage)
However, the gap in the backward pass is much larger:
- PyTorch 2.0 model → 337 ms
- Flux model → 544 ms (again, the Flux model does no masking, so it should have an advantage)
I would like to find the bottleneck in the backward pass and, if possible, write a custom backward rule to make things faster. How can I profile the different layers of the model during the Zygote backward pass?
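What I have in mind is timing the pullback of a single layer in isolation. A rough sketch of the idea (using a dummy Dense layer and made-up activation sizes, not my real model) would be:

using Flux, Zygote, CUDA, BenchmarkTools
# Dummy stand-in for one sub-layer (e.g. the c_attn Dense) plus a fake activation
layer = Dense(768, 3 * 768) |> gpu
act = CUDA.randn(Float32, 768, 1024, 8)  # (n_embed, block_size, batch_size)
# Capture the pullback for just this layer, then time the backward call alone
out, back = Zygote.pullback((m, x) -> m(x), layer, act)
@benchmark CUDA.@sync $back($out)  # reuse the output as a dummy cotangent seed

But I am not sure this is the right way to attribute time to the layers inside a Chain, hence the question.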
Here is my model code:
using Flux
using NNlib
using Flux: onehotbatch
function split_attention_join(qkv::DenseArray{T, 3}, nheads::Int) where {T <: AbstractFloat}
    # qkv is (3 * n_embed, seq_len, batch): split the fused projection into q, k, v
    q, k, v = eachslice(reshape(qkv, 3, :, size(qkv)[2:end]...), dims=1)
    o, A = NNlib.dot_product_attention(q, k, v, nheads=nheads)
    return o
end
struct PAE{M<:AbstractMatrix} # Position Embed
    W::M
end
Flux.@functor PAE
PAE((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = PAE(init(out, in))
(pae::PAE)(seq_in::AbstractMatrix{<:Integer}) = view(pae.W, :, 1:size(seq_in, 1))
function MHSelfAttention(n_embed, n_head, dropout=0.1)
    head_size = n_embed ÷ n_head
    c_attn = Dense(n_embed, 3 * n_embed, bias = true)
    c_proj = Dense(n_embed, n_embed, bias = true)
    resid_dropout = Dropout(dropout)
    return Chain(c_attn,
                 x -> split_attention_join(x, n_head),
                 c_proj,
                 resid_dropout)
end
function TransformerBlock(n_embed, n_head, dropout=0.1)
    attn = MHSelfAttention(n_embed, n_head, dropout)
    norm1 = LayerNorm(n_embed, affine=false)
    mlp = Chain(Dense(n_embed, 4 * n_embed, gelu),
                Dense(4 * n_embed, n_embed),
                Dropout(dropout))
    norm2 = LayerNorm(n_embed, affine=false)
    return Chain(
        SkipConnection(Chain(norm1, attn), +),
        SkipConnection(Chain(norm2, mlp), +)
    )
end
function PositionAwareEmbedding2(n_embed, vocab_size, block_size)
    return Parallel(.+,
                    Embedding(vocab_size, n_embed),
                    PAE(block_size => n_embed))
end
function Decoder(;vocab_size, n_embed, n_head, n_layers, block_size, dropout=0.1)
    emb = PositionAwareEmbedding2(n_embed, vocab_size, block_size)
    blocks = [TransformerBlock(n_embed, n_head, dropout) for _ in 1:n_layers]
    lm_head = Dense(emb.layers[1].weight')  # weight tying with the token embedding
    return Chain(emb, blocks..., lm_head)
end
Some more code for benchmarking:
using BenchmarkTools, CUDA, Flux, Flux.Losses
# FastMath and CUBLAS TF32 multiplication mode for fair comparison with PyTorch
CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.cublasMath_t(3))
CUDA.math_mode!(CUDA.FAST_MATH; precision=:Float16)
function make_data(;n_embed, batch_size, block_size, device_fn, vocab_size=65)
    xs = rand(1:vocab_size, block_size, batch_size)
    ys = rand(1:vocab_size, block_size, batch_size)
    ys = onehotbatch(ys, 1:vocab_size)
    return xs |> device_fn, ys |> device_fn
end
xs, ys = make_data(n_embed=768, batch_size=8, block_size=1024, device_fn=gpu, vocab_size=50257)
m2 = Decoder(;vocab_size=50257, n_embed=768, n_head=12, n_layers=12, dropout=Float32(0.1), block_size=1024) |> gpu
const x = xs
const y = ys
@benchmark CUDA.@sync m2(x) # Forward Pass
@benchmark CUDA.@sync Flux.gradient(m2 -> logitcrossentropy(m2(x), y), m2) # Backward pass (the gradient call also re-runs the forward pass)
Benchmark Results
@benchmark CUDA.@sync m2(x)
BenchmarkTools.Trial: 46 samples with 1 evaluation.
Range (min … max): 103.405 ms … 118.335 ms ┊ GC (min … max): 0.00% … 6.66%
Time (median): 103.884 ms ┊ GC (median): 0.00%
Time (mean ± σ): 110.361 ms ± 7.025 ms ┊ GC (mean ± σ): 3.31% ± 3.28%
Memory estimate: 383.45 KiB, allocs estimate: 7079.
@benchmark CUDA.@sync Flux.gradient(m2 -> logitcrossentropy(m2(x), y), m2)
BenchmarkTools.Trial: 10 samples with 1 evaluation.
Range (min … max): 539.152 ms … 548.887 ms ┊ GC (min … max): 1.49% … 2.39%
Time (median): 544.953 ms ┊ GC (median): 1.99%
Time (mean ± σ): 545.192 ms ± 2.573 ms ┊ GC (mean ± σ): 2.00% ± 0.22%
Memory estimate: 2.71 MiB, allocs estimate: 35363.
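To dig one level deeper, I could presumably wrap the gradient call in the CUDA profiler (assuming a recent CUDA.jl with the integrated profiler; older versions need an external Nsight Systems session), e.g.:

CUDA.@profile CUDA.@sync Flux.gradient(m2 -> logitcrossentropy(m2(x), y), m2)

but that only tells me which GPU kernels are slow, not which Flux layer or Zygote-generated pullback they come from.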
I am aware that @ToucheSir has previously mentioned that Torch and XLA both do some magic in their compilers; however, I would still like to know where my backward pass is slow and whether I can manually patch that part. Thank you!
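For reference, the general pattern I would follow for a hand-written backward rule, once the slow part is identified, looks like the sketch below. The function here (scale_by_two) is just a placeholder to show the mechanics, not one of my real layers:

using ChainRulesCore

scale_by_two(x::AbstractArray) = 2 .* x

function ChainRulesCore.rrule(::typeof(scale_by_two), x::AbstractArray)
    y = 2 .* x
    # The pullback returns a cotangent for the function itself (none) and for x
    scale_by_two_pullback(ȳ) = (NoTangent(), 2 .* ȳ)
    return y, scale_by_two_pullback
end

Zygote picks such rrules up automatically, so the real question is which part of the model deserves one.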
PS:
PyTorch model details: bench.py · GitHub
The benchmarking script is a modified version of Karpathy's original script. Please note that, for a fair comparison, we do not use BFloat16 or FlashAttention in the PyTorch model; it is the old vanilla hand-written attention from nanoGPT: nanoGPT/model.py at master · karpathy/nanoGPT · GitHub