I started off with PyTorch, but after hearing about Julia Flux I decided to try it out to reap the promised elegance and ease of use. I ported my transformer implementation to Julia and verified that the outputs match those of my PyTorch code. However, the model runs more slowly than the PyTorch version, and I cannot run it with a batch size of 64, which PyTorch handles easily, without getting a GPU out-of-memory error. The self-attention code in Julia is nearly identical to the PyTorch code, so I can only conclude that Julia's or Zygote's handling of these matrix operations is significantly less efficient than PyTorch's. Is there an easy way to make this code perform well in Julia without writing low-level code? Otherwise I'll just have to stick with PyTorch, as it handles everything I throw at it.
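For concreteness, the comparison is a single forward-plus-backward pass on the GPU at batch size 64, measured roughly like the sketch below (n_embd, block_size, and the layer sa stand in for my actual configuration):

using Flux, CUDA

# time one forward + backward pass; CUDA.@time also reports GPU allocations
x = CUDA.randn(Float32, n_embd, block_size, 64) # (n_embd, T, B) with batch size 64
CUDA.@time gradient(m -> sum(m(x)), sa)

My self-attention code is below, with the PyTorch version after it for comparison.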
function (sa::MultiHeadSelfAttention)(x)
    C, T, B = size(x) # (n_embd, T, B)
    # project x, then slice to get queries, keys, and values
    qkv = sa.c_attn(x) # (3 * n_embd, T, B)
    #q, k, v = (view(qkv, i:i+sa.n_embd-1, :, :) for i in 1:sa.n_embd:2*sa.n_embd+1) # (n_embd, T, B)
    q = view(qkv, 1:sa.n_embd, :, :) # (n_embd, T, B)
    k = view(qkv, 1+sa.n_embd:2*sa.n_embd, :, :) # (n_embd, T, B)
    v = view(qkv, 1+2*sa.n_embd:3*sa.n_embd, :, :) # (n_embd, T, B)
    # slice for each head
    head_size = sa.n_embd ÷ sa.n_heads
    #q, k, v = map(x->permutedims(reshape(x, head_size, sa.n_heads, T, B), [1, 3, 2, 4]), (q, k, v) ) # (head_size, T, n_head, B)
    q = permutedims(reshape(q, head_size, sa.n_heads, T, B), [1, 3, 2, 4]) # (head_size, T, n_head, B)
    k = permutedims(reshape(k, head_size, sa.n_heads, T, B), [1, 3, 2, 4]) # (head_size, T, n_head, B)
    v = permutedims(reshape(v, head_size, sa.n_heads, T, B), [1, 3, 2, 4]) # (head_size, T, n_head, B)
    # compute the attention logits
    k = permutedims(k, [2, 1, 3, 4]) # (T, head_size, n_head, B)
    scale = convert(eltype(sa.mask), 1 / sqrt(head_size)) # convert to the mask's eltype to avoid Float64 promotion
    wei = scale * batched_mul(k, q) # (T, T, n_head, B): rows index keys, columns index queries
    # add -Inf below the diagonal (key position > query position) so no query attends to the future
    masked_wei = wei .+ view(sa.mask, 1:T, 1:T)
    # normalize over the key dimension and weight the values
    coes = softmax(masked_wei; dims=1) # dims=1 is the default
    out = batched_mul(v, coes) # (head_size, T, n_head, B)
    # concatenate the heads
    out = permutedims(out, [1, 3, 2, 4]) # (head_size, n_head, T, B)
    out = reshape(out, sa.n_embd, T, B) # (n_embd, T, B)
    return out
end
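For context, here is a simplified sketch of the struct this method dispatches on, with a possible constructor (details such as keeping the mask out of the trainable parameters are omitted; field names match the usage above):

using Flux

struct MultiHeadSelfAttention
    c_attn      # Dense(n_embd => 3 * n_embd): fused projection producing q, k, and v
    mask        # (block_size, block_size) matrix: -Inf32 below the diagonal, 0 elsewhere
    n_embd::Int
    n_heads::Int
end
Flux.@functor MultiHeadSelfAttention

MultiHeadSelfAttention(n_embd, n_heads, block_size) = MultiHeadSelfAttention(
    Dense(n_embd => 3 * n_embd),
    [i > j ? -Inf32 : 0.0f0 for i in 1:block_size, j in 1:block_size],
    n_embd, n_heads)

The mask is built once at the full block_size and sliced to 1:T inside the forward pass, which is what the view(sa.mask, 1:T, 1:T) above relies on.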
# PyTorch implementation, for comparison
def forward(self, x):
    B, T, C = x.shape
    assert C == self.n_embd
    # split into queries, keys, and values
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # (B, T, n_embd)
    # split q, k, and v for each head
    head_size = C // self.n_head
    q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
    k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
    # compute attention scores
    wei = (q @ k.transpose(-2, -1)) * k.shape[-1]**-0.5 # (B, nh, T, T)
    wei = wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    wei = self.attn_dropout(wei)
    out = wei @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    # aggregate head outputs
    out = out.transpose(1, 2).contiguous().view(B, T, C)
    return out