Transformers.jl: causal mask on decoder

Hi, I am new to Transformers.jl and am trying to follow the tutorial (Tutorial · Transformers.jl). I wonder where I can find more details about this call:
t = decoder_trf(e, m, attention_mask, cross_attention_mask)

In particular, how do I modify the above so that a causal mask is applied to the decoder input (to avoid peeking ahead)? Many thanks!

@chengchingwen
To be more specific, for the look-ahead mask, should I do something along the lines of:

t = decoder_trf(e, m, NeuralAttentionlib.CausalMask(), cross_attention_mask)

thanks!

You don’t need to do that manually. The TransformerDecoderBlock constructor creates a CausalMultiheadQKVAttenOp for the self-attention, which already does the causal masking. The main purpose of the attention_mask argument in the decoder is to pass something like a LengthMask, so that padding does not affect the computation.
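For instance, a minimal sketch of that usage (assuming `e`, `m`, `decoder_trf`, and `cross_attention_mask` are set up as in the tutorial, and that the two sequences in the batch have the made-up true lengths 5 and 3):

```julia
using NeuralAttentionlib

# Hypothetical batch where the two sequences have true lengths 5 and 3;
# LengthMask marks the padded positions so they are ignored in attention.
attention_mask = NeuralAttentionlib.LengthMask([5, 3])

# The causal (look-ahead) mask is applied automatically by the
# CausalMultiheadQKVAttenOp inside each decoder block.
t = decoder_trf(e, m, attention_mask, cross_attention_mask)
```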


@chengchingwen that’s convenient and works like a charm!

A follow-up question: in the rare case where I don’t want this mask, or want a special mask that isn’t triangular, can this be done?

Yes, but you would probably need to call the inner-most constructor with MultiheadQKVAttenOp and pass your own attention mask from the input. You can look at NeuralAttentionlib for more kinds of masks.
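For the mask side of this, a rough sketch of what passing a non-triangular mask could look like (BandPartMask, GenericAttenMask, and the example lengths are just illustrations of what NeuralAttentionlib offers; it assumes `decoder_trf` was built from blocks using MultiheadQKVAttenOp as described above, so no causal mask is baked in):

```julia
using NeuralAttentionlib

# A banded, non-triangular pattern: each query may attend to at most 3
# positions before and 1 position after itself, further restricted to the
# (hypothetical) true sequence lengths 5 and 3. Masks compose with `&`.
custom_mask = NeuralAttentionlib.BandPartMask(3, 1) &
              NeuralAttentionlib.LengthMask([5, 3])

# Alternatively, a fully custom pattern from a Bool array describing which
# key positions each query may attend to (typically key × query × batch):
# custom_mask = NeuralAttentionlib.GenericAttenMask(bool_array)

# Passed the same way as before; the blocks themselves apply no causal mask.
t = decoder_trf(e, m, custom_mask, cross_attention_mask)
```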