Backward pass on a transformer model is not fast enough

I have implemented a basic transformer-like model in Flux and have a reference version in PyTorch 2.0 (which does a few extra ops, such as masking, but otherwise does basically the same amount of work).

I am getting comparable performance in the forward pass:

  • PyTorch 2.0 compiled → 115 ms
  • Flux model → 103 ms (the Flux model does not do masking, so it has an advantage)

However, the gap in the backward pass is much larger:

  • PyTorch 2.0 compiled → 337 ms
  • Flux model → 544 ms (the Flux model does not do masking, so it should have an advantage)

I would like to find the bottleneck in the backward pass and, if possible, write a custom backward rule to make it faster.

How can I profile the individual layers in a Zygote backward pass?
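One way to get per-layer numbers is to walk the top-level layers of the `Chain` and call `Zygote.pullback` on each one, timing the forward call (which also records the pullback) and the pullback call separately. The helper names `apply_layer` and `time_layers` below are hypothetical, not part of Flux or Zygote, and the sketch runs on the CPU; on the GPU each timed call would need to be wrapped in `CUDA.@sync`, otherwise `@elapsed` only measures asynchronous kernel launches.

```julia
using Flux, Zygote

# Hypothetical helper: call a layer on an input. Defined once so that the
# warm-up call and the timed calls compile to the same method.
apply_layer(m, z) = m(z)

# Hypothetical helper (not a Flux/Zygote API): time each top-level layer of a
# Chain separately. "fwd" includes recording the pullback, so it is slower
# than a plain forward call; "bwd" is the pullback alone. Assumes every layer
# returns an array.
function time_layers(model::Chain, x; nrep = 3)
    timings = NamedTuple{(:layer, :fwd, :bwd), Tuple{Int, Float64, Float64}}[]
    for (i, layer) in enumerate(model.layers)
        # pullback w.r.t. both the layer (its parameters) and its input
        y, back = Zygote.pullback(apply_layer, layer, x)
        dy = fill!(similar(y), one(eltype(y)))   # dummy upstream gradient
        back(dy)                                  # warm-up / compile
        fwd = minimum(@elapsed(Zygote.pullback(apply_layer, layer, x)) for _ in 1:nrep)
        bwd = minimum(@elapsed(back(dy)) for _ in 1:nrep)
        push!(timings, (layer = i, fwd = fwd, bwd = bwd))
        x = y                                     # next layer sees this layer's output
    end
    return timings
end

# Usage on a small CPU stand-in model
model = Chain(Dense(64 => 256, gelu), Dense(256 => 64), LayerNorm(64))
for t in time_layers(model, randn(Float32, 64, 32))
    println("layer $(t.layer): forward $(round(1000 * t.fwd; digits = 3)) ms, ",
            "backward $(round(1000 * t.bwd; digits = 3)) ms")
end
```

Applied to the model in this post, the top-level layers would be the embedding, each transformer block, and `lm_head`, which is a reasonable first granularity before drilling into a single block.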

Here is my model code:

```julia
using Flux
using NNlib
using Flux: onehotbatch

# Split the fused QKV projection into q, k, v and run multi-head attention.
function split_attention_join(qkv::DenseArray{T, 3}, nheads::Int) where {T <: AbstractFloat}
    q, k, v = eachslice(reshape(qkv, 3, :, size(qkv)[2:end]...), dims=1)
    o, _ = NNlib.dot_product_attention(q, k, v, nheads=nheads)  # second output: attention scores (unused)
    return o
end

struct PAE{M<:AbstractMatrix} # Position Embed
    W::M
end

Flux.@functor PAE

PAE((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = PAE(init(out, in))

# Position embeddings for the first size(seq_in, 1) positions: (n_embed, seq_len)
(pae::PAE)(seq_in::AbstractMatrix{<:Integer}) = view(pae.W, :, 1:size(seq_in, 1))

function MHSelfAttention(n_embed, n_head, dropout=0.1)
    head_size = n_embed ÷ n_head
    c_attn = Dense(n_embed, 3 * n_embed, bias = true)
    c_proj = Dense(n_embed, n_embed, bias = true)
    resid_dropout = Dropout(dropout)
    return Chain(c_attn,
                 x -> split_attention_join(x, n_head),
                 c_proj,
                 resid_dropout)
end

function TransformerBlock(n_embed, n_head, dropout=0.1)
    attn = MHSelfAttention(n_embed, n_head, dropout)
    norm1 = LayerNorm(n_embed, affine=false)
    mlp = Chain(Dense(n_embed, 4 * n_embed, gelu),
                Dense(4 * n_embed, n_embed),
                Dropout(dropout))
    norm2 = LayerNorm(n_embed, affine=false)
    return Chain(
        SkipConnection(Chain(norm1, attn), +),
        SkipConnection(Chain(norm2, mlp), +))
end

function PositionAwareEmbedding2(n_embed, vocab_size, block_size)
    return Parallel(.+,
                    Embedding(vocab_size, n_embed),
                    PAE(block_size => n_embed))
end

function Decoder(;vocab_size, n_embed, n_head, n_layers, block_size, dropout=0.1)
    emb = PositionAwareEmbedding2(n_embed, vocab_size, block_size)
    blocks = [TransformerBlock(n_embed, n_head, dropout) for _ in 1:n_layers]
    lm_head = Dense(emb.layers[1].weight')  # output layer tied to the token embedding
    return Chain(emb, blocks..., lm_head)
end
```

Some more code for benchmarking:

```julia
using BenchmarkTools, CUDA, Flux, Flux.Losses
# FastMath and CUBLAS TF32 multiplication mode for a fair comparison with PyTorch
CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.cublasMath_t(3))
CUDA.math_mode!(CUDA.FAST_MATH; precision=:Float16)

function make_data(;n_embed, batch_size, block_size, device_fn, vocab_size=65)
    xs = rand(1:vocab_size, block_size, batch_size)
    ys = rand(1:vocab_size, block_size, batch_size)
    ys = onehotbatch(ys, 1:vocab_size)
    return xs |> device_fn, ys |> device_fn
end

xs, ys = make_data(n_embed=768, batch_size=8, block_size=1024, device_fn=gpu, vocab_size=50257)

m2 = Decoder(;vocab_size=50257, n_embed=768, n_head=12, n_layers=12, dropout=Float32(0.1), block_size=1024) |> gpu

const x = xs
const y = ys

@benchmark CUDA.@sync m2(x) # Forward pass
@benchmark CUDA.@sync Flux.gradient(m2 -> logitcrossentropy(m2(x), y), m2) # Backward pass (the gradient call also runs a forward pass)
```

Benchmark Results

```
julia> @benchmark CUDA.@sync m2(x)
BenchmarkTools.Trial: 46 samples with 1 evaluation.
 Range (min … max):  103.405 ms … 118.335 ms  ┊ GC (min … max): 0.00% … 6.66%
 Time  (median):     103.884 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   110.361 ms ±   7.025 ms  ┊ GC (mean ± σ):  3.31% ± 3.28%

   █                                                        ▁    
  ▆█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▃▁▁▁▁▆█▃▅ ▁
  103 ms           Histogram: frequency by time          118 ms <

 Memory estimate: 383.45 KiB, allocs estimate: 7079.

julia> @benchmark CUDA.@sync Flux.gradient(m2 -> logitcrossentropy(m2(x), y), m2)
BenchmarkTools.Trial: 10 samples with 1 evaluation.
 Range (min … max):  539.152 ms … 548.887 ms  ┊ GC (min … max): 1.49% … 2.39%
 Time  (median):     544.953 ms               ┊ GC (median):    1.99%
 Time  (mean ± σ):   545.192 ms ±   2.573 ms  ┊ GC (mean ± σ):  2.00% ± 0.22%

  ▁                               ▁ █▁ ▁      ▁▁      ▁       ▁  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁██▁█▁▁▁▁▁▁██▁▁▁▁▁▁█▁▁▁▁▁▁▁█ ▁
  539 ms           Histogram: frequency by time          549 ms <

 Memory estimate: 2.71 MiB, allocs estimate: 35363.
```
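The second benchmark above times `Flux.gradient`, which runs the forward pass (recording the pullbacks) and then the backward pass, so the 544 ms figure includes one forward pass. A minimal sketch of timing the two halves separately with `Zygote.pullback`, on a small CPU stand-in model (`split_timing` is a hypothetical helper; on the GPU, wrap each timed call in `CUDA.@sync`):

```julia
using Flux, Zygote
using Flux.Losses: logitcrossentropy

# Hypothetical helper: Flux.gradient = forward (with pullback recording)
# followed by backward. Zygote.pullback exposes the two halves.
function split_timing(model, x, y)
    loss(m) = logitcrossentropy(m(x), y)
    l, back = Zygote.pullback(loss, model)
    back(one(l))                                    # warm-up / compile
    t_fwd = @elapsed Zygote.pullback(loss, model)   # forward + recording
    l, back = Zygote.pullback(loss, model)
    t_bwd = @elapsed back(one(l))                   # backward only
    return (forward = t_fwd, backward = t_bwd)
end

model = Chain(Dense(16 => 64, relu), Dense(64 => 10))
x = randn(Float32, 16, 256)
y = Flux.onehotbatch(rand(1:10, 256), 1:10)
println(split_timing(model, x, y))
```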

I am aware that @ToucheSir has previously mentioned that Torch and XLA both do some magic in their compilers; however, I would still like to know where my backward pass is slow and whether I can manually patch that part. Thank you!


PyTorch model details (GitHub): the benchmarking script is a modified version of Karpathy's original nanoGPT script (karpathy/nanoGPT on GitHub). Please note that, for a fair comparison, we do not use BFloat16 or FlashAttention in the PyTorch model; it uses plain, handwritten attention.


NNlib.dot_product_attention currently only has a “reference implementation” which hasn’t been optimized. This is mostly because it’s brand new and optimized implementations haven’t been added yet. If you’re looking for optimized versions right now, check out Transformers.jl and its underlying library NeuralAttentionLib.jl.
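To make "reference implementation" concrete, here is a minimal, unoptimized multi-head attention sketch in plain Julia: it materializes the full (key length × query length) score matrix for every head and batch entry, applies a softmax over the keys, and multiplies by V. `naive_attention` is a hypothetical name, and the head layout (head `h` takes rows `(h-1)*dh+1:h*dh` of a `(d, len, batch)` array) is my assumption of what NNlib uses, not something taken from its source.

```julia
using LinearAlgebra

# Hypothetical, deliberately naive multi-head attention.
# q: (d, lq, batch), k and v: (d, lk, batch); d must be divisible by nheads.
function naive_attention(q::AbstractArray{T,3}, k::AbstractArray{T,3},
                         v::AbstractArray{T,3}; nheads::Int) where {T<:AbstractFloat}
    d, lq, b = size(q)
    dh = d ÷ nheads
    out = similar(q)
    for bi in 1:b, h in 1:nheads
        rows = (h - 1) * dh + 1 : h * dh
        Q = q[rows, :, bi]; K = k[rows, :, bi]; V = v[rows, :, bi]
        S = (K' * Q) ./ sqrt(T(dh))            # (lk × lq) scaled scores
        S = exp.(S .- maximum(S; dims = 1))      # stable softmax over keys...
        A = S ./ sum(S; dims = 1)                # ...each column sums to 1
        out[rows, :, bi] = V * A                 # (dh × lq) weighted values
    end
    return out
end
```

Every step allocates a fresh array, which is exactly the kind of code an optimized kernel would fuse.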

I am checking out NeuralAttentionLib.jl, but I am sure that is not the issue, because the PyTorch implementation is a custom-defined layer written in exactly the same style as NNlib's attention.

My concern here is that the Zygote backward pass takes 4-5x as long as the forward pass, even when I remove the attention layers.

If the difference isn't in the attention operations, I think a more minimal example would be in order. I would be curious to see the performance of PyTorch in eager mode, as that's closer to Flux semantics. If the difference between eager and compiled is significant, it should be pretty clear why we're slower too.
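A minimal reproducer could drop attention entirely and keep only the pre-norm MLP blocks, then report how long a gradient call takes relative to a forward call. The sketch below assumes placeholder sizes and uses hypothetical helper names (`mlp_stack`, `fwd_bwd_ratio`); scaling it up and moving it to the GPU would be needed to compare against the original numbers.

```julia
using Flux

# Hypothetical stack of transformer-style MLP blocks (no attention).
function mlp_stack(n_embed, n_layers)
    block() = SkipConnection(Chain(LayerNorm(n_embed; affine = false),
                                   Dense(n_embed => 4n_embed, gelu),
                                   Dense(4n_embed => n_embed)), +)
    return Chain([block() for _ in 1:n_layers]...)
end

# Ratio of gradient time to forward time. Note the gradient call itself
# includes one forward pass, so a ratio of 1 would mean a free backward pass.
function fwd_bwd_ratio(model, x; nrep = 5)
    loss(m) = sum(abs2, m(x))
    Flux.gradient(loss, model)                       # warm-up / compile
    t_fwd = minimum(@elapsed(loss(model)) for _ in 1:nrep)
    t_grad = minimum(@elapsed(Flux.gradient(loss, model)) for _ in 1:nrep)
    return t_grad / t_fwd
end

model = mlp_stack(64, 4)
x = randn(Float32, 64, 128, 4)                       # (n_embed, seq, batch)
println("gradient / forward time ratio: ", fwd_bwd_ratio(model, x))
```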

RE profiling, any of the popular methods for profiling Julia code should work for Flux models.
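For example, the `Profile` stdlib can sample a loop of gradient calls and print a flat report sorted by sample count. This is a CPU sketch with a small stand-in model; because GPU kernels run asynchronously, CUDA.jl's `CUDA.@profile` (or NVIDIA Nsight Systems) is likely the better tool for the GPU version.

```julia
using Flux, Profile

# Small CPU stand-in for the real model (assumption, not the original model).
model = Chain(Dense(64 => 256, gelu), Dense(256 => 64))
x = randn(Float32, 64, 128)
loss(m) = sum(abs2, m(x))

Flux.gradient(loss, model)        # compile first so compilation isn't profiled
Profile.clear()
@profile for _ in 1:1000
    Flux.gradient(loss, model)
end
# Flat view sorted by how many samples landed in each function;
# entries with fewer than 10 samples are hidden.
Profile.print(format = :flat, sortedby = :count, mincount = 10)
```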


Okay, so I benchmarked the PyTorch model in eager mode, and I'm sad to report that it is in about the same range as Julia: 531 ms (PyTorch) vs 544 ms (Julia). So maybe Zygote is not bad, and the gap comes from our rules not fusing ops in the backward pass? Any suggestions for possible optimizations?
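One way to hand-fuse part of the backward pass is a custom `ChainRulesCore.rrule`, which Zygote picks up automatically. The sketch below uses a tanh-approximated GELU written as a single scalar function (`gelu_t`, `gelu_t_deriv` and `fused_gelu` are hypothetical names) whose pullback is one broadcast over the input. NNlib may already ship rules for common activations such as `gelu`, so treat this as a template for composite elementwise expressions that lack a rule, not as a confirmed speedup for this model.

```julia
using ChainRulesCore, Zygote

# tanh approximation of GELU as one scalar function
function gelu_t(x::AbstractFloat)
    c = oftype(x, sqrt(2 / π))
    a = oftype(x, 0.044715)
    return x / 2 * (1 + tanh(c * (x + a * x^3)))
end

# its analytic derivative, also as one scalar function
function gelu_t_deriv(x::AbstractFloat)
    c = oftype(x, sqrt(2 / π))
    a = oftype(x, 0.044715)
    t = tanh(c * (x + a * x^3))
    return (1 + t) / 2 + x / 2 * (1 - t^2) * c * (1 + 3a * x^2)
end

fused_gelu(x::AbstractArray) = gelu_t.(x)

# Custom rule: the whole backward pass is a single fused broadcast.
function ChainRulesCore.rrule(::typeof(fused_gelu), x::AbstractArray)
    y = fused_gelu(x)
    function fused_gelu_pullback(ȳ)
        x̄ = unthunk(ȳ) .* gelu_t_deriv.(x)
        return NoTangent(), x̄
    end
    return y, fused_gelu_pullback
end
```

In a Flux model, the function could be placed in a `Chain` as its own layer, e.g. `Chain(Dense(n => 4n), fused_gelu, Dense(4n => n))`.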

I would like to highlight that the difference in the forward pass with compiled PyTorch is negligible, so compilation only yields a meaningful speedup for them in the backward pass.

Do you think ReverseDiff would help in this case? I remember reading that it has something called a compiled tape.
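For reference, ReverseDiff's compiled-tape workflow looks roughly like this on a toy function of a single array (`X` and `f` are placeholders, not part of the model). The tape records one execution trace, so value-dependent control flow is frozen at recording time; whether this helps a Flux model of this size is untested here.

```julia
using ReverseDiff

const X = randn(4, 10)                   # fixed toy input (placeholder)
f(W) = sum(tanh.(W * X))

W0 = randn(3, 4)
tape = ReverseDiff.GradientTape(f, W0)   # record the operations once
ctape = ReverseDiff.compile(tape)        # compile the recorded tape
grad = similar(W0)
ReverseDiff.gradient!(grad, ctape, W0)   # replay; reusable for new inputs of the same shape
```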

Zygote will unfuse all elementwise broadcasts (pointwise operations in Python library parlance), yes. I’m not sure if ReverseDiff does the same but you could try it.
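To see what unfusing costs, here is a plain-Julia illustration with no AD involved: a fused broadcast such as `@. a * b + c` runs as one loop and allocates only its result, while the unfused version materializes the intermediate `a .* b` first.

```julia
# One fused loop, one output allocation.
fused(a, b, c) = @. a * b + c

# Two separate broadcasts: an intermediate array plus the output.
function unfused(a, b, c)
    t = a .* b
    return t .+ c
end

a, b, c = rand(10^6), rand(10^6), rand(10^6)
fused(a, b, c); unfused(a, b, c)         # compile first
bytes_fused = @allocated fused(a, b, c)
bytes_unfused = @allocated unfused(a, b, c)
println("fused: ", bytes_fused ÷ 2^20, " MiB, unfused: ", bytes_unfused ÷ 2^20, " MiB")
```

On a GPU, each extra broadcast is also an extra kernel launch and an extra round trip through device memory.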

PyTorch’s compiled mode has dedicated functionality to handle cases like this. To date, I’m not aware of anyone who has explored doing something similar in Julia (whether for AD or not).
