Flux Transformer Out of Memory

I started off with PyTorch, but after hearing about Julia Flux I decided to try it out to reap the promised elegance and ease of use. I ported my transformer implementation to Julia and verified that the outputs match those of my PyTorch code. However, the model runs more slowly than the PyTorch version, and I cannot run it with a batch size of 64, which PyTorch handles easily, without getting a GPU out-of-memory error. The self-attention code in Julia is nearly identical to the PyTorch code, so I can only conclude that Julia’s or Zygote’s handling of matrix operations is significantly subpar compared to PyTorch’s. Is there some easy way to make the code work well in Julia without writing low-level code? Otherwise I’ll just have to stick with PyTorch, as it handles everything I throw at it.

function (sa::MultiHeadSelfAttention)(x)
    C, T, B = size(x) # (n_embd, T, B)
    
    # slice to get queries, keys, and values
    qkv = sa.c_attn(x) # (3 * n_embd, T, B)
    #q, k, v = (view(qkv, i:i+sa.n_embd-1, :, :) for i in 1:sa.n_embd:2*sa.n_embd+1) # (n_embd, T, B)
    q = view(qkv, 1:sa.n_embd, :, :) # (n_embd, T, B)
    k = view(qkv, 1+sa.n_embd:2*sa.n_embd, :, :) # (n_embd, T, B)
    v = view(qkv, 1+2*sa.n_embd:3*sa.n_embd, :, :) # (n_embd, T, B)

    # slice for each head
    head_size = sa.n_embd Ă· sa.n_heads
    #q, k, v = map(x->permutedims(reshape(x, head_size, sa.n_heads, T, B), [1, 3, 2, 4]), (q, k, v) ) # (head_size, T, n_head, B)
    q = permutedims(reshape(q, head_size, sa.n_heads, T, B), [1, 3, 2, 4])
    k = permutedims(reshape(k, head_size, sa.n_heads, T, B), [1, 3, 2, 4])
    v = permutedims(reshape(v, head_size, sa.n_heads, T, B), [1, 3, 2, 4])

    # compute the logits
    k = permutedims(k, [2, 1, 3, 4]) # (T, head_size, n_head, B)
    scale = convert(eltype(sa.mask), 1 / sqrt(head_size)) # convert for performance
    wei = scale * batched_mul(k, q) # (T, T, n_head, B)

    # add a triangular matrix filled with -inf on the bottom to prevent attention to the future
    masked_wei = wei .+ view(sa.mask, 1:T, 1:T)

    # add up the values
    coes = softmax(masked_wei)
    out = batched_mul(v, coes) # (head_size, T, n_head, B)

    # concatenate the heads
    out = permutedims(out, [1, 3, 2, 4]) # (head_size, n_head, T, B)
    out = reshape(out, sa.n_embd, T, B) # (n_embd, T, B)

    return out
end
# PyTorch implementation
def forward(self, x):
    B, T, C = x.shape
    
    assert C == self.n_embd
    
    # split into queries, keys, and values
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # (B, T, n_emb)
    # split q, k, and v for each head
    head_size = C // self.n_head
    q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
    k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
    
    # compute attention scores
    wei = ( q @ k.transpose(-2, -1) ) * k.shape[-1]**-0.5
    wei = wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    wei = self.attn_dropout(wei)
    out = wei @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    
    # aggregate head outputs
    out = out.transpose(1, 2).contiguous().view(B, T, C)
    return out

It’s hard to comment without seeing the definition of each layer, but just off the bat I’m not sure the definitions are comparable. For example, taking 3 views is not comparable to .split in PyTorch. Likewise, I believe permutedims copies the source array whereas PyTorch’s .transpose may not.
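To make the copying point concrete, here is a tiny check you can run (nothing model-specific is assumed here):

x  = rand(Float32, 3, 4)
xp = permutedims(x, (2, 1))   # eagerly allocates and fills a new 4×3 array
x[1, 2] = 99f0
xp[2, 1] == 99f0              # false: xp was copied, so it no longer tracks x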

If you want to see a more optimized version of MHA in Julia, see the implementation of the brand new Flux layer and related supporting code.


I checked out the MHA code in the links and it seems essentially equivalent to my code except that my code stores the query, key, and value matrices in one matrix instead of storing them separately. Perhaps the attention code isn’t the bottleneck? I’ll link my complete code. I am new to Julia so it would be great to know if I am doing something particularly inefficiently or if I am using poor style.

Is Zygote’s requirement that only pure functions can be differentiated a cause for worry? After Zygote generates the differentiated versions, is the compiler smart enough to convert some of the code into in-place operations?

That is important, because a function like .split can be more memory efficient when differentiated than multiple indexing/view operations. Note how the NNlib code I linked specifically avoids doing this. See also my comment about permutedims vs .transpose: not having to copy saves a lot of memory!

Not in the way you think, because Zygote doesn’t have such a requirement. The only in-place operations which are forbidden are on arrays, which can be limiting but means many other “impure” language features are still allowed. There are definitely cases where Zygote is not as memory efficient as PyTorch’s AD (and I think some of them are at play here), but those apply even if you use zero in-place operations with both libraries.
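As a toy illustration of where the restriction actually bites (a sketch only, not code from the model above):

using Zygote

f_ok(x)  = sum(abs2, x)                        # no array mutation: differentiates fine
f_bad(x) = (y = copy(x); y[1] = 0f0; sum(y))   # writes into an array

Zygote.gradient(f_ok, rand(Float32, 4))   # works
Zygote.gradient(f_bad, rand(Float32, 4))  # errors: mutating arrays is not supported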

I altered the attention layer to use separate q, k, and v matrices and saw no significant improvement in the memory footprint. The culprit of the memory issues must be the copies that permutedims and others are doing. I am confused as to how Flux could have overlooked crucial functions such as a transpose that doesn’t make copies. Is there some limitation that precludes the use of functions that PyTorch users take for granted in the Flux ecosystem, or has no one gotten around to implementing them?

As it stands, PyTorch provides efficient tensor operations that allow for high-level implementations without sacrificing performance or incurring too high a memory footprint. It seems that Flux isn’t there yet. I commend the Flux team for what they have accomplished, and I would jump at Flux the moment it becomes competitive with PyTorch. Is Flux on track to achieve this, or does Flux make trade-offs that will hinder it from ever having the raw performance of PyTorch?

No and no. To resolve your confusion, we should clarify where permutedims is defined. Unlike PyTorch, where the framework defines and provides implementations for each operator, many functions you call in Flux models are defined elsewhere. In the case of permutedims, that elsewhere is actually the Julia standard library (Base). A good analogy would be if you could use numpy.transpose instead of torch.Tensor.transpose in PyTorch and have it just work.

However, that just pushes the question upstream: why does the stdlib permutedims copy? I don’t know the correct answer, but I’ve asked around for a historical record of this decision and will update this thread if/when I receive one. The more relevant answer is that you can perform a non-copying dim transpose by using PermutedDimsArray. This is essentially what PyTorch is doing under the hood, and is also part of the Base stdlib. The biggest caveat of PermutedDimsArray is that not all user-defined functions may understand the wrapper and take the most efficient codepath. Relevant to this thread’s example, note how the last post on the second linked thread mentions NNlib’s batched_mul routines. What’s that in the docs page? PermutedDimsArray. This is precisely why I linked the definition of dot_product_attention in NNlib: it shows you how to use these tools to write an efficient attention operation which works much like the PyTorch one does.
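To show roughly what that looks like in practice, here is a small sketch using NNlib’s lazy batched transpose (assuming a recent NNlib, which exports batched_mul and batched_transpose); it is not the dot_product_attention code itself, just the core idea:

using NNlib

q = rand(Float32, 16, 8, 4)          # (head_size, T, batch)
k = rand(Float32, 16, 8, 4)

kt_copy = permutedims(k, (2, 1, 3))  # eager copy into a (T, head_size, batch) array
kt_lazy = batched_transpose(k)       # lazy wrapper over the same memory, no copy

wei_copy = batched_mul(kt_copy, q)   # (T, T, batch)
wei_lazy = batched_mul(kt_lazy, q)   # same result, dispatched to batched BLAS without the copy
wei_copy ≈ wei_lazy                  # true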

All that said, I think your follow-up question is more or less addressed:

  1. Most of the “Flux” performance here is actually Base Julia performance and should be discussed accordingly.
  2. Direct translations are often not apples-to-apples and knowing what the idiomatic patterns are in each language (e.g. PermutedDimsArray to avoid copies) can make a big difference.
  3. Because DL frameworks have such large API surface areas and NN models vary greatly, “competitiveness” is always context sensitive. PyTorch definitely gets the most engineering effort towards optimizing its operations, but that doesn’t always translate to a performance win.

I tried using PermutedDimsArray instead of permutedims and it doesn’t work for this scenario, as the batch dimension must be moved around. Permuting the batch dimension with PermutedDimsArray leads to GPU array indexing operations running on the CPU, which completely destroys performance.

It doesn’t seem possible to replicate PyTorch’s efficiency in this scenario without resorting to more convoluted algorithms or low-level CUDA programming. I’ll just have to use a specific, highly engineered API such as PyTorch instead of relying on Julia’s generic operations. The performance-vs-simplicity trade-off here seems to favor PyTorch (in terms of model construction by users, not low-level API implementation).

Have you tried NeuralAttentionlib.jl (GitHub - chengchingwen/NeuralAttentionlib.jl: Reusable functionality for defining custom attention/transformer layers)? This is the library that powers Transformers.jl.

I could just use Transformers.jl, but what I am really trying to do here is determine whether I can justify using Julia’s Flux over PyTorch. The scenario of implementing a custom transformer really highlights the strengths and weaknesses of each API. PyTorch makes it easy to write a performant transformer using only basic tensor operations, while all the similarly simple implementations in Flux are much less efficient. Why would I use Julia/Flux over PyTorch if I can easily implement custom models in PyTorch, whereas achieving similar performance in Flux would require much more engineering on my part or hunting down someone else’s CUDA implementation?

I would still be interested in the exact number of allocations each implementation makes. Does the forward function allocate a similar amount of memory in each implementation? Does PyTorch free the memory more eagerly, or does Zygote simply use more temporary arrays? PyTorch does have a more expressive array view system, but to make things run fast it usually still makes a contiguous copy. IMO the problem here isn’t just split/transpose/permutedims/..., but some memory management issue.
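For reference, one way to get those numbers on the Julia side (assuming CUDA.jl and Zygote are loaded; model and x are placeholders for your own setup):

CUDA.@time model(x)                                # GPU allocations for the forward pass
CUDA.@time Zygote.gradient(m -> sum(m(x)), model)  # forward + backward together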

I would kindly disagree. While I am not very knowledgeable about customs in the PyTorch ecosystem, it is common in Flux (and Lux) for people to write their own AD rules to make things performant, because it is very easy. These rules do not have to end up in Flux (or NNlib) because they are specific to the application, or there might not be consensus about their usefulness, API, etc.
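To illustrate how little code such a rule takes, here is a toy rrule via ChainRulesCore for a made-up myscale function (nothing from the model above); Zygote picks these rules up automatically:

using ChainRulesCore

myscale(x, a) = a .* x   # toy op (a is a scalar here) standing in for whatever you want a custom rule for

function ChainRulesCore.rrule(::typeof(myscale), x, a)
    y = myscale(x, a)
    myscale_pullback(ȳ) = (NoTangent(), a .* ȳ, sum(ȳ .* x))  # gradients w.r.t. (f, x, a)
    return y, myscale_pullback
end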

For a batch size of 32, PyTorch uses about 3% of my GPU’s memory while Julia uses about 80%. After playing around with it, I think you are correct that the main issue isn’t with split/transpose/permutedims/.... Even when I changed the attention layer to just return its input, the memory usage was still excessive.

I’d be curious why you think a function like NNlib.batched_mul (which is used in both Flux and Transformers.jl for efficient attention operations) is confusing. It’s very much a “basic” array operation at the level of the PyTorch ones you’ve been using. My point this whole time with linking the (pretty readable) NNlib dot_product_attention code as a reference is that no additional engineering or hunting down CUDA-specific implementations should be required on your end.
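For concreteness, calling it looks roughly like this (check the docstring of your installed NNlib version, as the keyword interface may differ; the sizes are made up):

using NNlib

n_embd, n_heads, T, B = 64, 4, 16, 8
q = rand(Float32, n_embd, T, B)
k = rand(Float32, n_embd, T, B)
v = rand(Float32, n_embd, T, B)

# causal multi-head attention in one call, no hand-rolled permutedims
y, α = NNlib.dot_product_attention(q, k, v;
                                   nheads = n_heads,
                                   mask   = NNlib.make_causal_mask(q))
size(y)   # (n_embd, T, B)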

It is even easier not to write custom AD rules. There’s no point in using Flux in my projects if I have to manually differentiate just to approach PyTorch’s performance.

That’s misleading because CUDA.jl caches memory allocations to speed up subsequent allocations. Unless you’re running out of memory, the number you see on nvidia-smi isn’t representative of how much the model is actually using.
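To see what the pool is actually holding versus what nvidia-smi reports, something like the following should work (standard CUDA.jl calls):

using CUDA

CUDA.memory_status()   # prints pool usage vs. total device memory
GC.gc()                # drop unreferenced CuArrays on the Julia side
CUDA.reclaim()         # hand cached blocks back to the driver
CUDA.memory_status()   # compare after reclaiming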

That might be another case where the memory just isn’t freed or has been reserved by the memory pool. Do you run GC.gc(); CUDA.reclaim() after the computation to make sure Julia frees unused memory?

I tried freeing the memory pool and the memory usage was still over 80%.

How large is your GPU? My 75M-parameter transformer model takes ~3 GB after training and freeing. I would guess there’s something captured by your REPL or elsewhere so the memory cannot be freed.


You can also set attribute!(memory_pool(device()), CUDA.MEMPOOL_ATTR_RELEASE_THRESHOLD, UInt(0)). The reported memory usage will then be more accurate.


My GPU has 6 GB of memory. Once all the calculations are done, only 6% is used, but the backpropagation itself uses around 80%. The attribute! code you provided yields an error saying the operation isn’t supported.