I do use a 3D array for the NER labels. Just wrap them with Basic.Vocabulary and one-hot encode the batch.
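
For a quick picture of what that means, here is a minimal sketch (the label list is abbreviated; it reuses the same Vocabulary and Flux.onehot calls as the script below):

using Flux
using Transformers.Basic

demo_vocab = Vocabulary(["O", "B-PER", "I-PER", "B-LOC", "I-LOC"], "O")
# one-hot encoding a batch (a vector of label sequences) yields a
# 3D array: (number of labels) x (max sequence length) x (batch size)
y = Flux.onehot(demo_vocab, [["B-PER", "I-PER", "O"], ["B-LOC", "O", "O"]])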
Here is the script I used to process the CoNLL-2003 dataset:
using JSON3
using Arrow
using Flux
using Transformers
using Transformers.Basic

# the label names for each tag set come from the dataset's dataset_info.json
const datainfo = open(JSON3.read, "./datasets/conll2003/dataset_info.json")
const pos_labels = collect(datainfo.features.pos_tags.feature.names)
const chunk_labels = collect(datainfo.features.chunk_tags.feature.names)
const ner_labels = collect(datainfo.features.ner_tags.feature.names)

# wrap the label lists in vocabularies (the second argument is the unknown token)
const pos_vocab = Vocabulary(pos_labels, ".")
const chunk_vocab = Vocabulary(chunk_labels, chunk_labels[1])
const ner_vocab = Vocabulary(ner_labels, ner_labels[1])

# the dataset splits are stored as Arrow tables
const trainset = Arrow.Table("./datasets/conll2003/conll2003-train.arrow")
const devset = Arrow.Table("./datasets/conll2003/conll2003-validation.arrow")
const testset = Arrow.Table("./datasets/conll2003/conll2003-test.arrow")
const train_num = length(trainset.id)
const dev_num = length(devset.id)
const test_num = length(testset.id)
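
The script assumes wordpiece, tokenizer, and the vocab used in process below are already in scope; with Transformers.jl they would typically come from a pretrained BERT, something like this (which checkpoint string you use is up to you):

using Transformers.Pretrain

# load a pretrained BERT; pretrain"..." returns the model together with
# its WordPiece and tokenizer
const bert_model, wordpiece, tokenizer = pretrain"bert-uncased_L-12_H-768_A-12"
# the wordpiece vocabulary maps tokens to ids in process below
const vocab = Vocabulary(wordpiece)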
# re-tokenize the whitespace-split words with the wordpiece tokenizer and
# record, for each produced subword, the index of the original word
function retoken(wp, tk, tokens)
    retokens = Vector{String}(undef, 0)
    wordbounds = Vector{Int}(undef, 0)
    _len = length(tokens)
    sizehint!(retokens, _len)
    sizehint!(wordbounds, _len)
    for (i, token) in enumerate(tokens)
        ntokens = wp(tk(token))
        append!(retokens, ntokens)
        # every subword of `token` points back to word index i
        foreach(_ -> push!(wordbounds, i), 1:length(ntokens))
    end
    # release the extra hinted capacity now that the final sizes are known
    sizehint!(retokens, length(retokens))
    sizehint!(wordbounds, length(wordbounds))
    # @assert wp(tk(join(tokens, ' '))) == retokens
    return retokens, wordbounds
end
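
For example (the subword split here is made up for illustration):

tokens, bounds = retoken(wordpiece, tokenizer, ["Washington", "rallied"])
# if wordpiece splits "rallied" into "rall" and "##ied", this gives
# tokens == ["washington", "rall", "##ied"] and bounds == [1, 2, 2]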
# gather the columns for a set of sentence ids into a named tuple
function getbatch(dataset, ids)
    tks = dataset.tokens[ids]
    chks = dataset.chunk_tags[ids]
    poss = dataset.pos_tags[ids]
    ners = dataset.ner_tags[ids]
    return (token = tks, chunk = chks, pos = poss, ner = ners)
end
# expand the word-level integer tags to subword level: the first subword of a
# word keeps its tag, and the remaining subwords have the "B" prefix rewritten
# to "I" so the BIO scheme stays consistent
function relabel(wb, label, labels)
    relabels = Vector{String}(undef, 0)
    sizehint!(relabels, length(labels))
    base = 1
    @assert first(wb) == base
    for i in wb
        l = labels[i] + 1  # dataset tags are 0-based; Julia arrays are 1-based
        if base == i
            # first subword of word i: keep the original tag
            push!(relabels, label[l])
            base += 1
        else
            # continuation subword: B-XXX becomes I-XXX
            push!(relabels, replace(label[l], r"^B" => 'I'))
        end
    end
    return relabels
end
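
So a word tagged B-LOC that the wordpiece stage split in two comes out as B-LOC, I-LOC (assuming tag id 5 maps to "B-LOC" in this dataset's ner_tags):

relabel([1, 1], ner_labels, [5])  # == ["B-LOC", "I-LOC"]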
# preprocess a single sentence: wordpiece re-tokenization plus tag expansion
function preprocess(wordpiece, tokenizer, sample)
    token, wb = retoken(wordpiece, tokenizer, sample.token)
    chunk = relabel(wb, chunk_labels, sample.chunk)
    pos = relabel(wb, pos_labels, sample.pos)
    ner = relabel(wb, ner_labels, sample.ner)
    return (token = token, chunk = chunk, pos = pos, ner = ner, bounds = wb)
end
# preprocess a batch of sentences, one vector entry per sentence
function preprocess_batch(wordpiece, tokenizer, sample)
    batch = length(sample.token)
    token = Vector{Vector{String}}(undef, batch)
    wb = Vector{Vector{Int}}(undef, batch)
    chunk = similar(token)
    pos = similar(token)
    ner = similar(token)
    for i = 1:batch
        token[i], wb[i] = retoken(wordpiece, tokenizer, sample.token[i])
        chunk[i] = relabel(wb[i], chunk_labels, sample.chunk[i])
        pos[i] = relabel(wb[i], pos_labels, sample.pos[i])
        ner[i] = relabel(wb[i], ner_labels, sample.ner[i])
    end
    return (token = token, chunk = chunk, pos = pos, ner = ner, bounds = wb)
end
# prepend [CLS] and append [SEP] so the sequence matches the BERT input format
addsstok(x, start_token = "[CLS]", sep_token = "[SEP]") = [start_token; x; sep_token]
# turn a raw batch into model-ready arrays: padded token ids, segment ids,
# masks, and 3D one-hot tag arrays (labels x sequence length x batch)
function process(wordpiece, tokenizer, sample)
    batch = preprocess_batch(wordpiece, tokenizer, sample)
    token = batch.token
    tok = map(addsstok, token)
    mask = Basic.getmask(batch.token)  # mask over label positions, no [CLS]/[SEP]
    atten_mask = Basic.getmask(tok)    # attention mask including [CLS]/[SEP]
    tok_id = vocab(tok)                # wordpiece token ids as a padded matrix
    segment = ones(Int, size(tok_id))  # single-sentence input: all segment 1
    pos = Flux.onehot(pos_vocab, batch.pos)
    chunk = Flux.onehot(chunk_vocab, batch.chunk)
    ner = Flux.onehot(ner_vocab, batch.ner)
    bounds = Tuple(batch.bounds)
    return (input = (tok = tok_id, segment = segment), mask = mask, atten_mask = atten_mask,
            pos = pos, chunk = chunk, ner = ner, bounds = bounds)
end
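
With all of that in place, building one training batch looks like this (batch size 4 is arbitrary):

sample = getbatch(trainset, 1:4)
data = process(wordpiece, tokenizer, sample)
# data.input.tok is a padded id matrix; data.ner is the 3D one-hot array
# for the NER labels: (number of ner tags) x (max sequence length) x 4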