Hi,
I’m following this tutorial on implementing transformer architectures and have a question about how best to implement a self-attention block.
My current code for a single attention head looks like this:
# Self-attention head
using Flux                  # Dense, softmax
using LinearAlgebra: tril
import Transformers         # Transformers.batchedmul

key   = Dense(n_embed, head_size, bias=false)
query = Dense(n_embed, head_size, bias=false)
value = Dense(n_embed, head_size, bias=false)
tril_mask = tril(ones(block_size, block_size)) .== 0   # true above the diagonal

function head(x)
    C, T, B = size(x)              # (Channels, Tokens, Batch)
    k = key(x)                     # (head_size, T, B)
    q = query(x)                   # (head_size, T, B)
    v = value(x)                   # (head_size, T, B)
    # scaled dot-product scores, (T, T, B)
    wts3 = Transformers.batchedmul(q, k, transA=true) ./ sqrt(1f0 * C)
    # mask out future tokens, then normalize over the key dimension
    wts3[tril_mask[1:T, 1:T], :] .= -1f10
    wts3 = softmax(wts3; dims=2)   # (T, T, B)
    out = permutedims(Transformers.batchedmul(wts3, v, transB=true), (2, 1, 3))   # (head_size, T, B)
end
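For context, I call head on arrays of size (n_embed, T, B); with placeholder hyperparameters picked just for illustration (say n_embed = 32, head_size = 16, block_size = 8, defined before constructing the layers above), the shapes work out as:
x = randn(Float32, n_embed, block_size, 4)   # (C, T, B), batch of 4
y = head(x)                                  # runs the forward pass above
size(y)                                      # (head_size, T, B) == (16, 8, 4)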
The lower-triangular mask tril restricts information flow so that each token can only attend to preceding tokens (and itself). For a 4-token sequence (T = 4), the masking step looks like this:
julia> wts3 = Transformers.batchedmul(q, k, transA=true) ./ sqrt(1f0 * C);
julia> wts3
4×4×1 Array{Float32, 3}:
[:, :, 1] =
1.03351 2.32426 3.46095 2.27605
-0.873403 4.64232 1.02749 -3.3062
-0.329188 2.06334 1.62721 0.66499
2.08628 -2.60349 -0.965947 1.51771
julia> wts3[tril_mask[1:T, 1:T], :] .= -1f10;
julia> wts3
4×4×1 Array{Float32, 3}:
[:, :, 1] =
1.03351 -1.0f10 -1.0f10 -1.0f10
-0.873403 4.64232 -1.0f10 -1.0f10
-0.329188 2.06334 1.62721 -1.0f10
2.08628 -2.60349 -0.965947 1.51771
The forward pass works, but because of the in-place array mutation, Zygote needs a Buffer to take the gradient. Ideally I’d like to avoid this, so I looked through Transformers.jl to understand how that library implements the masking. For comparison, PyTorch implementations 1 2 use the masked_fill operation.
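For what it’s worth, the mutation-free workaround I’m experimenting with is an additive mask that gets broadcast over the batch dimension, in the spirit of masked_fill. This is just a minimal sketch of my own, not how Transformers.jl does it:
using Flux                  # softmax
using LinearAlgebra: tril

# Mutation-free masking sketch: add -1f10 above the diagonal instead of writing
# into the score array, so Zygote doesn't need a Buffer.
function causal_softmax(wts)                                          # wts :: (T, T, B)
    T = size(wts, 1)
    neg_mask = ifelse.(tril(ones(Float32, T, T)) .== 0, -1f10, 0f0)   # (T, T)
    return softmax(wts .+ neg_mask; dims=2)                           # broadcasts over the batch dim
end
This gives effectively the same forward result as the in-place version and Zygote differentiates it without a Buffer, but I’d still like to understand how the library handles it.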
Looking through Transformers.jl, self-attention seems to be implemented here.
For causal attention with the mask above, it looks like causal=true is the right argument, so atten_op_constr = CausalMultiheadQKVAttenOp in this line.
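Just to spell out what I expect “causal” to mean here: it should be exactly the pattern my tril_mask encodes, where token i may only attend to tokens 1:i. A quick sanity check in plain Julia (not library code):
julia> tril(ones(4, 4)) .== 0    # true above the diagonal, i.e. the positions to mask out
4×4 BitMatrix:
 0  1  1  1
 0  0  1  1
 0  0  0  1
 0  0  0  0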
In a forward call, the self-attention block applies attention to the Q, K, V projections here.
This resolves here, where the mask is in args and mixingf = weighted_sum_mixing.
This is followed by this call. The attention calculation is deferred to mixing, where f = mixingf.
Finally, weighted_sum_mixing goes back to scaled_matmul here.
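If I’m reading the chain correctly, it should end up computing something equivalent to the following paraphrase, written in the (head_size, T, B) layout from my snippet above (this is my own sketch, not the library’s actual code):
import Transformers
using Flux                  # softmax

# My paraphrase of the score -> normalize -> mixing chain, WITHOUT the mask
function attention_unmasked(q, k, v)   # q, k, v :: (head_size, T, B)
    d = size(q, 1)
    s = Transformers.batchedmul(q, k, transA=true) ./ sqrt(Float32(d))   # scores, (T, T, B)
    s = softmax(s; dims=2)             # normalize over the key dimension
    return Transformers.batchedmul(v, s, transB=true)                    # weighted sum, (head_size, T, B)
end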
I’ve been digging through the code but can’t find the place where the causal mask is actually applied.
Does anybody have a pointer?