Implementation of self-attention in Transformers.jl?

I’m following this tutorial on implementing transformer architectures and have a question on how to best implement a self-attention block.

My current code for a single attention head looks like this:

# Self-attention head
using Flux, LinearAlgebra, Transformers

key = Dense(n_embed, head_size, bias=false)
query = Dense(n_embed, head_size, bias=false)
value = Dense(n_embed, head_size, bias=false)
tril_mask = tril(ones(block_size, block_size)) .== 0 # true strictly above the diagonal

function head(x)
    C, T, B = size(x) # (n_embed, Token, Batch)
    k = key(x) # (head_size, T, B)
    q = query(x) # (head_size, T, B)
    v = value(x) # (head_size, T, B)
    wts3 = Transformers.batchedmul(q, k, transA=true) ./ sqrt(1f0 * C) # (T, T, B)
    wts3[tril_mask[1:T, 1:T], :] .= -1f10 # mask future positions (in-place mutation)
    wts3 = softmax(wts3; dims=2) # size (T, T, B)
    out = permutedims(Transformers.batchedmul(wts3, v, transB=true), (2, 1, 3)) # (head_size, T, B)
    return out
end

The lower-triangular mask tril_mask restricts information flow so that each token only sees itself and the preceding tokens. For example, for a 4-token sequence (T=4) the attention weights before and after masking look like this:

julia> wts3 = Transformers.batchedmul(q, k, transA=true) ./ sqrt(1f0 * C);

julia> wts3
4×4×1 Array{Float32, 3}:
[:, :, 1] =
  1.03351    2.32426   3.46095    2.27605
 -0.873403   4.64232   1.02749   -3.3062
 -0.329188   2.06334   1.62721    0.66499
  2.08628   -2.60349  -0.965947   1.51771

julia> wts3[tril_mask[1:T, 1:T], :] .= -1f10;

julia> wts3
4×4×1 Array{Float32, 3}:
[:, :, 1] =
  1.03351   -1.0f10   -1.0f10    -1.0f10
 -0.873403   4.64232  -1.0f10    -1.0f10
 -0.329188   2.06334   1.62721   -1.0f10
  2.08628   -2.60349  -0.965947   1.51771

The forward pass works, but because of the array mutation, Zygote needs a Buffer to take the gradient.
Ideally, I’d like to avoid this, so I looked through Transformers.jl to understand how the library implements it.

Also, comparable PyTorch implementations (1, 2) use the masked_fill operation.
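
A non-mutating analogue of masked_fill can be written in Julia with a broadcasted ifelse. The following is only a minimal sketch, not code taken from Transformers.jl: the helper name masked_fill is hypothetical, and the (query, key, batch) layout of the scores is assumed to match the code above.

```julia
using LinearAlgebra

# Hypothetical helper: returns a new array equal to `x`, with `value` wherever `mask` is true.
# Nothing is mutated in place, so reverse-mode AD should not need a Buffer for this step.
masked_fill(x, mask, value) = ifelse.(mask, convert(eltype(x), value), x)

T, B = 4, 2
scores = randn(Float32, T, T, B)             # (query, key, batch)
future = triu(trues(T, T), 1)                # true where key index > query index
masked = masked_fill(scores, future, -1f10)  # the (T, T) mask broadcasts over the batch dimension
```

Because the mask is a plain Bool matrix, it can be built once for block_size and sliced to [1:T, 1:T] for shorter sequences.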

Looking through Transformers.jl, self-attention seems to be implemented here

For causal attention, using the mask above, it looks like causal=true is the right argument. So atten_op_constr = CausalMultiheadQKVAttenOp in this line.

In a forward call, the self-attention block applies attention to the Q, K, V projections here

This resolves here
where the mask is in args and mixingf = weighted_sum_mixing.

Followed by this

Attention calculation is deferred to mixing
where f=mixingf.

Finally, weighted_sum_mixing goes back to scaled_matmul here

I’m digging down into the code, but can’t find the place where the causal mask is actually applied.
Does anybody have a pointer?


Here is how you can mask without mutation:

using LinearAlgebra # provides tril

# -1f8 strictly below the diagonal, 0 elsewhere
make_decoder_mask(block_size) = tril(fill(-1f8, block_size, block_size), -1)
mask = make_decoder_mask(block_size)

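Now, after calculating the attention matrix A, add the mask to it. Note that the strictly lower triangle only corresponds to future positions if A is laid out as (key, query); for a (query, key) layout, use triu(..., 1) instead: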

A .+ mask

If your mask is constant, as it is when training a decoder, you would probably store it as a field of your self-attention layer and make sure that only the relevant parameters are trainable, using something like this:

Flux.trainable(m::MHSelfAttention) = (m.MH_QKV, m.MH_O)

I have a messy implementation here which also supports custom masks, such as ones you might want to use for an encoder, but it should give you an idea: MakeMore.jl/model.jl at main · reachtarunhere/MakeMore.jl · GitHub

Instead of a Zygote buffer to mutate the array, one can also just add an upper triangular matrix with large negative values. That has the same effect:

wts3 = wts3 .+ triu(ones(eltype(wts3), T, T), 1) .* -1f10
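
To check that the additive mask really has the same effect as the in-place version, one can compare the softmax outputs of both. This is a small sketch under stated assumptions: softmax_dims is a hand-written stand-in for the softmax from Flux/NNlib so that the snippet has no dependencies, and the scores use the (query, key, batch) layout from the question.

```julia
using LinearAlgebra

# Numerically stable softmax along `dims` (stand-in for the library softmax).
function softmax_dims(x; dims)
    e = exp.(x .- maximum(x; dims=dims))
    return e ./ sum(e; dims=dims)
end

T, B = 4, 1
scores = randn(Float32, T, T, B)  # (query, key, batch)

# Additive mask: -1f10 strictly above the diagonal, 0 elsewhere. No mutation involved.
additive = triu(ones(eltype(scores), T, T), 1) .* -1f10
w_add = softmax_dims(scores .+ additive; dims=2)

# Mutating version from the question, for comparison.
mutated = copy(scores)
mutated[triu(trues(T, T), 1), :] .= -1f10
w_mut = softmax_dims(mutated; dims=2)

same_weights = w_add ≈ w_mut                      # both variants give the same weights
no_future = all(iszero, w_add[1, 2:end, 1])       # first token ignores all later tokens
```

Each row of w_add still sums to one, since the masked entries contribute exp of a very large negative number, which underflows to zero in Float32.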

Thanks @reachtarunhere. That solves the implementation part (I had the same idea after spelling it out for this post).

I’m still interested in understanding the implementation in NeuralAttentionlib.jl though…

Do check out NNlib.jl/attention.jl at master · FluxML/NNlib.jl · GitHub too, as the attention op is now part of NNlib and has a very readable implementation which supports masking, bias, dropout etc.
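
As a rough usage sketch, causal attention with NNlib could look like the following. It assumes a recent NNlib version that provides dot_product_attention and make_causal_mask; the sizes d, T and B are arbitrary illustration values, and the shapes in the comments are the ones described in the NNlib documentation.

```julia
using NNlib

d, T, B = 8, 4, 2
q = randn(Float32, d, T, B)  # (feature, sequence, batch)
k = randn(Float32, d, T, B)
v = randn(Float32, d, T, B)

# Bool mask of size (T, T); true marks positions that may be attended to.
mask = NNlib.make_causal_mask(q)

# Returns the attention output and the attention scores.
y, α = NNlib.dot_product_attention(q, k, v; mask=mask)

size(y)  # (d, T, B)
size(α)  # (T, T, 1, B), laid out as (key, query, head, batch)
```

Note that the scores here are laid out as (key, query), so the allowed region is the upper triangle including the diagonal, which is the transpose of the (query, key) layout used in the question.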


The CausalMultiheadQKVAttenOp calls NeuralAttentionlib.multihead_qkv_attention with NeuralAttentionlib.CausalMask, which creates a non-allocating broadcastable object indicating which positions are involved in the computation. There are lots of different kinds of masks in NeuralAttentionlib.jl, and applying them is nothing more than a broadcast. The part you are missing is just the masked_score.
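
To illustrate the general idea of a mask that stores no entries and is applied by broadcasting, here is a minimal sketch. It is not NeuralAttentionlib's actual implementation: the names LazyCausalMask and apply_mask are hypothetical, and a (key, query, batch) layout of the scores is assumed.

```julia
# Hypothetical lazy mask: behaves like a Bool matrix, but only stores its size.
struct LazyCausalMask <: AbstractMatrix{Bool}
    n::Int
end
Base.size(m::LazyCausalMask) = (m.n, m.n)
# Entry (i, j) is true when key i is visible to query j, i.e. i <= j.
Base.getindex(m::LazyCausalMask, i::Int, j::Int) = i <= j

# Applying the mask is a single broadcast: keep allowed scores, set the rest to -Inf.
apply_mask(score, mask) = ifelse.(mask, score, convert(eltype(score), -Inf))

T, B = 4, 2
score = randn(Float32, T, T, B)                # (key, query, batch), assumed layout
masked = apply_mask(score, LazyCausalMask(T))  # mask broadcasts over the batch dimension
```

Since the mask entries are computed on the fly inside the broadcast, no (T, T) array is ever materialized, and no in-place mutation is needed.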


Thanks @chengchingwen

What is the syntax for using masks, @chengchingwen?
This script uses the syntax from the documentation but throws an error:

using NeuralAttentionlib

q = ones(Float32, 4)
k = ones(Float32, 4)
v = ones(Float32, 4)

NeuralAttentionlib.naive_qkv_attention(q, k, v; mask=NeuralAttentionlib.:CausalMask)

ERROR: LoadError: MethodError: no method matching naive_qkv_attention(::Vector{Float32}, ::Vector{Float32}, ::Vector{Float32}; mask=NeuralAttentionlib.CausalMask)
Closest candidates are:
  naive_qkv_attention(::AbstractArray, ::AbstractArray, ::AbstractArray, ::Any...) at ~/.julia/packages/NeuralAttentionlib/F0XsF/src/functional/attention.jl:34 got unsupported keyword argument "mask"
  naive_qkv_attention(::typeof(NeuralAttentionlib.score_returning), ::AbstractArray, ::AbstractArray, ::AbstractArray, ::Any...) at ~/.julia/packages/NeuralAttentionlib/F0XsF/src/functional/attention.jl:37 got unsupported keyword argument "mask"
 [1] top-level scope
   @ ~/source/repos/tinygpt/src/test_neuralattn.jl:14
in expression starting at /Users/ralph/source/repos/tinygpt/src/test_neuralattn.jl:14

I’ve also followed the link to the source but that didn’t help.

using NeuralAttentionlib

q = ones(Float32, 10, 7)
k = ones(Float32, 10, 4)
v = ones(Float32, 10, 4)

NeuralAttentionlib.naive_qkv_attention(q, k, v, NeuralAttentionlib.CausalMask())
NeuralAttentionlib.naive_qkv_attention(NeuralAttentionlib.score_returning, q, k, v, NeuralAttentionlib.CausalMask()).attention_score

  1. mask should be an object, not a type.
  2. mask is passed as a positional argument, not as a keyword.
  3. q/k/v have the shape (hidden size, length 1 size, ..., length n size, batch size), so a Vector input is treated as a single element (for which attention has no visible effect).