Using Transformers.jl for "is next sentence"

Currently, the BERT model loaded with pretrain"" and the model loaded with hgf"" exist side by side, and their APIs differ slightly.

First, pretrain"bert-uncased_L-12_H-768_A-12" and hgf"bert-base-uncased:fornextsentenceprediction" each load the entire model, so using both would keep two copies of the model weights in memory. This is avoidable: if you want to use the model from hgf"", load only the wordpiece and the tokenizer from pretrain"", like this:

model = hgf"bert-base-uncased:fornextsentenceprediction"
wordpiece = pretrain"bert-uncased_L-12_H-768_A-12:wordpiece"
tokenizer = pretrain"bert-uncased_L-12_H-768_A-12:tokenizer"

Second, models from hgf"" only accept batched input, so the inputs must be reshaped to size (sequence length, batch size). In Julia, the batch dimension is the last axis:

token_indices = reshape(token_indices, length(token_indices), 1)
segment_indices = reshape(segment_indices, length(segment_indices), 1)

where the 1 means the minibatch contains a single input (here, one sentence pair).
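To see this layout on plain arrays, here is a minimal sketch with made-up token ids (no model or Transformers.jl involved). reshape(v, :, 1) is equivalent to reshape(v, length(v), 1). For more than one input, each column holds one sequence; the sketch assumes the sequences already have equal length, since real pairs of different lengths would need padding, which is not shown here.

```julia
# Made-up token ids for two already-encoded sequences of equal length
seq_a = [102, 7, 8, 103]
seq_b = [102, 9, 10, 103]

# A single sequence becomes a (sequence length, batch size) = (4, 1) matrix
single = reshape(seq_a, :, 1)   # `:` lets Julia infer the first dimension

# Equal-length sequences stack as columns: the batch is the last axis, size (4, 2)
batch = hcat(seq_a, seq_b)
```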

Finally, call the model with token_indices as the positional argument and pass segment_indices through the keyword argument token_type_ids (matching the behavior of huggingface/transformers):

result = model(token_indices; token_type_ids=segment_indices)

result.logits then holds the next-sentence prediction scores.
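To turn these scores into a probability, one option is a softmax over the first (class) dimension. Below is a minimal sketch on a made-up logits matrix. It assumes logits has size (2, batch size) and that row 1 means "text2 follows text1" while row 2 means "text2 is random", i.e. the Hugging Face label order (0 = continuation, 1 = random) shifted to Julia's 1-based indexing; softmax_cols is a hypothetical helper written here, not part of Transformers.jl.

```julia
# Made-up logits for a batch with one sentence pair: size (2, 1)
logits = reshape([3.0, -2.0], 2, 1)

# Numerically stable softmax over the class dimension (dims = 1)
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims = 1))
    return e ./ sum(e; dims = 1)
end

probs = softmax_cols(logits)

# Assumed label order: row 1 = "is next", row 2 = "not next"
is_next_prob = probs[1, :]                               # one probability per pair
is_next = [argmax(col) == 1 for col in eachcol(logits)]  # hard decision per pair
```

If NNlib (which Flux uses) is loaded, NNlib.softmax(logits; dims=1) should give the same result as softmax_cols.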

The full code:

using Transformers
using Transformers.Basic
using Transformers.Pretrain
using Transformers.HuggingFace

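# Accept DataDeps download prompts automatically
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

# Load the NSP model from hgf"", and only the wordpiece and tokenizer from pretrain""
model = hgf"bert-base-uncased:fornextsentenceprediction"
wordpiece = pretrain"bert-uncased_L-12_H-768_A-12:wordpiece"
tokenizer = pretrain"bert-uncased_L-12_H-768_A-12:tokenizer"

# Vocabulary that maps word pieces to indices
vocab = Vocabulary(wordpiece)

# Tokenize each text and split it into word pieces
text1 = "Aesthetic Appreciation and Spanish Art:" |> tokenizer |> wordpiece
text2 = "Insights from Eye-Tracking" |> tokenizer |> wordpiece
# BERT sentence-pair format: [CLS] text1 [SEP] text2 [SEP]
formatted_text = ["[CLS]"; text1; "[SEP]"; text2; "[SEP]"]

token_indices = vocab(formatted_text)
# Segment 1: [CLS], text1 and the first [SEP]; segment 2: text2 and the final [SEP]
segment_indices = [fill(1, length(text1)+2); fill(2, length(text2)+1)]
# hgf"" models expect (sequence length, batch size); here the batch size is 1
token_indices = reshape(token_indices, length(token_indices), 1)
segment_indices = reshape(segment_indices, length(segment_indices), 1)

result = model(token_indices; token_type_ids=segment_indices)
result.logits  # next-sentence prediction scores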
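The formatting and segment-index steps in the full code can be factored into a small function when you want to score more than one pair. A minimal pure-Julia sketch: build_pair_inputs is a hypothetical name (not part of Transformers.jl), and the word-piece lists are made up rather than produced by tokenizer and wordpiece.

```julia
# Hypothetical helper: assemble a BERT-style sentence-pair input from two word-piece lists
function build_pair_inputs(text1::AbstractVector{<:AbstractString},
                           text2::AbstractVector{<:AbstractString})
    formatted = ["[CLS]"; text1; "[SEP]"; text2; "[SEP]"]
    # Segment 1 covers [CLS], text1 and the first [SEP]; segment 2 covers text2 and the final [SEP]
    segments = [fill(1, length(text1) + 2); fill(2, length(text2) + 1)]
    # Return segments already shaped as (sequence length, batch size = 1)
    return formatted, reshape(segments, :, 1)
end

formatted, segment_indices = build_pair_inputs(["aesthetic", "appreciation"], ["insights"])
```

The returned tokens would still go through vocab and the same reshape before being passed to the model.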