Transformers.jl: How to train for masked language model tasks in Julia?

What am I missing here? It seems easy enough, but I cannot find any documentation on how to do it.


@chengchingwen

I can't write a detailed example right now, but hopefully this gives a basic idea of how it could be done. There are plenty of ways; here is a simple case where the input sentences already contain the mask token (β€œ[MASK]”). I use the BERT tokenizer for quick illustration, but you can use another tokenizer or define one yourself (the displayed TextEncoder is slightly different due to my local version, but the difference can be ignored). The MatchTokenization holds the mask pattern, which prevents it from being split. We modify the encode process by inserting a function that computes the masked positions. Then you should be able to use masked_position as a mask with a cross-entropy loss (a rough sketch of this follows the REPL example below). Everything else is just feeding the input to your model and normal Flux model training.

julia> using Transformers, FuncPipelines, TextEncodeBase

julia> using TextEncodeBase: nested2batch, nestedcall
julia> textenc = HuggingFace.load_tokenizer("bert-base-cased")
TrfTextEncoder(
β”œβ”€ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_cased_tokenizer, WordPiece(vocab_size = 28996, unk = [UNK], max_char = 100)), 5 patterns)),
β”œβ”€ vocab = Vocab{String, SizedArray}(size = 28996, unk = [UNK], unki = 101),
β”œβ”€ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)),
β”œβ”€ annotate = annotate_strings,
β”œβ”€ onehot = lookup_first,
β”œβ”€ decode = nestedcall(remove_conti_prefix),
β”œβ”€ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)),
└─ process = Pipelines:
  ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source)
  ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token)
  ╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]:<type=1> Input[1]:<type=1> [SEP]:<type=1> (Input[2]:<type=2> [SEP]:<type=2>)...)(target.token)
  ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token)
  ╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token)
  ╰─ target[token] := TextEncodeBase.nested2batch(target.token)
  ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment)
  ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment)
  ╰─ target := (target.token, target.segment, target.attention_mask)
)

julia> textenc.tokenizer.tokenization
MatchTokenization(WordPieceTokenization(bert_cased_tokenizer, WordPiece(vocab_size = 28996, unk = [UNK], max_char = 100)), 5 patterns)

julia> textenc.tokenizer.tokenization.patterns
5-element Vector{Regex}:
 r"\Q[PAD]\E"
 r"\Q[UNK]\E"
 r"\Q[CLS]\E"
 r"\Q[SEP]\E"
 r"\Q[MASK]\E"

julia> textenc2 = TextEncoders.TransformerTextEncoder(textenc) do e
           e.process[1:5] |> Pipeline{:masked_position}(nested2batch ∘ nestedcall(isequal("[MASK]")), :token) |>
               e.process[6:end-1] |> PipeGet{(:token, :segment, :attention_mask, :masked_position)}()
       end
TrfTextEncoder(
β”œβ”€ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_cased_tokenizer, WordPiece(vocab_size = 28996, unk = [UNK], max_char = 100)), 5 patterns)),
β”œβ”€ vocab = Vocab{String, SizedArray}(size = 28996, unk = [UNK], unki = 101),
β”œβ”€ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)),
β”œβ”€ annotate = annotate_strings,
β”œβ”€ onehot = lookup_first,
β”œβ”€ decode = nestedcall(remove_conti_prefix),
β”œβ”€ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)),
└─ process = Pipelines:
  ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source)
  ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token)
  ╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]:<type=1> Input[1]:<type=1> [SEP]:<type=1> (Input[2]:<type=2> [SEP]:<type=2>)...)(target.token)
  ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token)
  ╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token)
  ╰─ target[masked_position] := (TextEncodeBase.nested2batch ∘ (x->TextEncodeBase.nestedcall(Base.Fix2{typeof(isequal), String}(isequal, "[MASK]"), x)))(target.token)
  ╰─ target[token] := TextEncodeBase.nested2batch(target.token)
  ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment)
  ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment)
  ╰─ target := (target.token, target.segment, target.attention_mask, target.masked_position)
)

julia> samples = ["[MASK] sample senten 1", "sample [MASK] sentence 2", "just put the batch of sentences with [MASK] in a vector"];

julia> input = TextEncoders.encode(textenc2, samples)
(token = Bool[0 0 … 1 1; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0;;; 0 0 … 1 1; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0;;; 0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], segment = [1 1 1; 1 1 1; … ; 1 1 1; 1 1 1], attention_mask = NeuralAttentionlib.LengthMask{1, Vector{Int32}}(Int32[7, 6, 13]), masked_position = Bool[0 0 0; 1 0 0; … ; 0 0 0; 0 0 0])

julia> input.masked_position
13Γ—3 Matrix{Bool}:
 0  0  0
 1  0  0
 0  1  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  1
 0  0  0
 0  0  0
 0  0  0
 0  0  0

julia> TextEncoders.decode(textenc2, input.token)
13Γ—3 Matrix{String}:
 " [CLS]"   " [CLS]"     " [CLS]"
 " [MASK]"  " sample"    " just"
 " sample"  " [MASK]"    " put"
 " sent"    " sentence"  " the"
 "en"       " 2"         " batch"
 " 1"       " [SEP]"     " of"
 " [SEP]"   " [PAD]"     " sentences"
 " [PAD]"   " [PAD]"     " with"
 " [PAD]"   " [PAD]"     " [MASK]"
 " [PAD]"   " [PAD]"     " in"
 " [PAD]"   " [PAD]"     " a"
 " [PAD]"   " [PAD]"     " vector"
 " [PAD]"   " [PAD]"     " [SEP]"

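To connect the example to actual training: below is a minimal, untested sketch of how masked_position could feed a cross-entropy loss inside an ordinary Flux training loop. Here model, dataloader, and labels are placeholders (not part of the Transformers.jl API): the model is assumed to return per-position logits of size (vocab_size, sequence_length, batch), and labels is assumed to be the one-hot encoding of the original, unmasked tokens (e.g. obtained by encoding the unmasked sentences with textenc).

using Flux

# Sketch only: `model` is assumed to map the encoded input to logits of size
# (vocab_size, seq_len, batch); `labels` is the matching one-hot array of the
# original (unmasked) tokens.
function masked_lm_loss(model, input, labels)
    logits = model(input)
    logp = Flux.logsoftmax(logits; dims = 1)
    # broadcast the (seq_len, batch) mask over the vocabulary dimension
    mask = reshape(input.masked_position, 1, size(input.masked_position)...)
    nll = -sum(labels .* logp .* mask)        # cross entropy at masked positions only
    return nll / max(sum(input.masked_position), 1)
end

# Ordinary Flux training loop (explicit-gradient style); `dataloader` is assumed
# to yield (input, labels) pairs produced by your own batching code.
opt_state = Flux.setup(Flux.Adam(1f-4), model)
for (input, labels) in dataloader
    grads = Flux.gradient(m -> masked_lm_loss(m, input, labels), model)
    Flux.update!(opt_state, model, grads[1])
end

Dividing by the number of masked tokens turns the summed loss into a per-mask average, which keeps the gradient scale independent of how many tokens happen to be masked in a batch.
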
Thank you for your answer. However, isn't bert-base-cased a pretrained model?
I need to train my own model from scratch and then perform the mask task. I will try to work out how to do my own tokenization, as you wrote. Any pointers on how that can be done?

I only use the tokenizer from bert-base-cased; no model is involved. The main question is: to what degree do you want to start from scratch? The widely used subword tokenization methods are statistics-based, which means you would need to prepare your own corpus and compute statistics such as word frequencies. The quality of the tokenization depends on the quality of the corpus, so it's common to use or modify the tokenizer of another pretrained model.

However, if you do intend to build your own tokenizer from scratch, you could either use other tools such as huggingface/tokenizers to β€œtrain” the tokenizer and export it, or consider using BytePairEncoding.jl to β€œtrain” your own BPE tokenizer.

Our tokenization is built with TextEncodeBase.jl. You can find examples of how to build a tokenizer in its tests. Unfortunately, the tokenization part is mostly undocumented. Let me know if you need further information.
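
If it helps, here is a rough sketch of building a small word-level text encoder entirely from scratch (no pretrained model involved), adapted from the copy-task example in the Transformers.jl documentation. The vocabulary is just an illustrative placeholder, and keyword names may differ slightly between versions, so treat this as a sketch rather than copy-paste code:

using Transformers, Transformers.TextEncoders

# Illustrative vocabulary; in practice you would collect it from your own corpus.
vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "the", "a", "sample", "sentence"]

# Whitespace tokenization via `split`; special symbols are passed as keywords.
myenc = TransformerTextEncoder(split, vocab;
    startsym = "[CLS]", endsym = "[SEP]", unksym = "[UNK]",
    padsym = "[PAD]", trunc = 128)

TextEncoders.encode(myenc, ["a sample [MASK] sentence"])

The masked_position pipeline shown earlier can then be added to this encoder in the same way as for the BERT tokenizer.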