Writing a fast NLP tokenizer in Julia

Hello everyone,

As a learning exercise, I’m interested in writing a byte-pair-encoding (BPE) tokenizer in Julia.
The meat of this tokenizer would be finding the most frequently occurring pairs of tokens (initially these are just the individual characters) in a big text file and merging them into a single token.
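Concretely, the naive training loop I have in mind looks something like this toy sketch (tokens as plain strings here purely for illustration; how to represent them efficiently is exactly what I'm unsure about):

function bpe_toy(text, n_merges)
	tokens = [string(c) for c in replace(text, " " => "_")] # start from single characters
	for _ in 1:n_merges
		# count all adjacent pairs
		counts = Dict{Tuple{String,String},Int}()
		for i in 1:length(tokens)-1
			pair = (tokens[i], tokens[i+1])
			counts[pair] = get(counts, pair, 0) + 1
		end
		isempty(counts) && break
		best = findmax(counts)[2] # most frequent pair
		# rebuild the token stream with that pair merged
		merged = String[]
		i = 1
		while i <= length(tokens)
			if i < length(tokens) && (tokens[i], tokens[i+1]) == best
				push!(merged, tokens[i] * tokens[i+1])
				i += 2
			else
				push!(merged, tokens[i])
				i += 1
			end
		end
		tokens = merged
	end
	return tokens
end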

I’ve looked at SentencePiece and I’d like the tokenizer to operate on raw text where the spaces are replaced by _ and treated as a regular character of the stream, like they do.
Now I’m wondering how I should represent tokens and how I would go about the merging operation.
I think I’ll definitely need a (sorted?) dictionary to keep track of the (pairs of?) tokens and their number of occurrences in the text.
Then I see two naive ways of going about this:

  • store the text as one big string where tokens are delimited by a special character and find and replace the most frequent pairs using RegEx.
  • assign a number to each token and convert the text to an array of these numbers

I originally asked this question on Slack where Scott Paul Jones already replied that RegEx probably wasn’t a good idea from a performance perspective.
The other approach seems really inelegant to me and I also wouldn’t know how to performantly “merge” two array elements together.
On Slack, Jeffrey Sarnoff suggested using symbols instead of strings, as these apparently aren’t garbage-collected.

I’m also aware of BytePairEncoding.jl, which essentially does what I want, except that it doesn’t create tokens that cross word boundaries. And I think tokens are represented as strings there.

I’m sure there are a lot of different ideas for tackling this problem, so I’d love to hear how I could do this efficiently.

Thanks in advance,
Jules

You could use codeunits to treat any string as a vector of code units (bytes, for a UTF-8 string), so you don’t have to replace anything. You could then fill an array of the same length as your codeunits vector with token “membership” information. This should be reasonably efficient and avoids repeatedly inserting and deleting stuff.
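For example (the exact display may differ between Julia versions):

julia> codeunits("ab_c")
4-element Base.CodeUnits{UInt8,String}:
 0x61
 0x62
 0x5f
 0x63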

Now, I don’t know about BPE: does it allow more than two adjacent characters to become a token, or does it only search for pairs?
In the latter case, iterating over all codeunits once should allow you to create a co-occurrence matrix pretty efficiently, from which you should be able to compute the best pairs that you want to join into a common token.
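Something like this one-pass sketch, using a 256×256 matrix indexed by byte values (just to illustrate the idea; once tokens grow beyond single bytes you would index by token ids instead):

function pair_counts(s)
	cu = codeunits(s)
	counts = zeros(Int, 256, 256) # counts[a+1, b+1] = how often byte a is followed by byte b
	for i in 1:length(cu)-1
		counts[cu[i] + 1, cu[i+1] + 1] += 1 # +1 because bytes are 0-based
	end
	return counts
end

counts = pair_counts("this_is_a_test")
n, idx = findmax(counts) # idx is the CartesianIndex of the most frequent pair
best_pair = (Char(idx[1] - 1), Char(idx[2] - 1))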

Thanks for the response.
Code units do indeed look promising, I didn’t know about those! Although I’m not quite sure what you mean by the membership information?
As for the BPE algorithm, there seem to be a lot of variants, but as I see it, every step of the algorithm merges exactly one pair of tokens.
Thanks again!

That was badly worded, sorry. I meant a vector indicating how long a token starting at this position is. It’s easy to manipulate and can be used to gather the resulting tokenized string pretty efficiently.
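A tiny hand-made example of what I mean (here the vector stores token lengths, and only the entries at token-start positions are ever read):

# gather the tokens described by a vector of token lengths
function gather(data, tokenlen)
	tokens = String[]
	i = 1
	while i <= length(data)
		push!(tokens, String(data[i:i+tokenlen[i]-1]))
		i += tokenlen[i]
	end
	return tokens
end

data = codeunits("This_is")
tokenlen = [2, 0, 2, 0, 1, 2, 0] # tokens: "Th", "is", "_", "is"
gather(data, tokenlen) # ["Th", "is", "_", "is"]

Merging the token starting at position i with the one after it is then just tokenlen[i] += tokenlen[i + tokenlen[i]].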

Btw, I tried that myself for fun. Would this be correct?

julia> println(teststring)
This is a test. It contains some test words for redundancy and blablablabla...

julia> show(pairencode(teststring))
[['T', 'h'], ['i'], ['s', ' '], ['i'], ['s', ' '], ['a'], [' ', 't'], ['e', 's'], ['t', '.'], [' ', 'I'], ['t', ' '], ['c'], ['o', 'n'], ['t', 'a'], ['i', 'n'], ['s', ' '], ['s', 'o'], ['m', 'e'], [' ', 't'], ['e', 's'], ['t', ' '], ['w'], ['o', 'r'], ['d'], ['s', ' '], ['f'], ['o', 'r'], [' ', 'r'], ['e', 'd'], ['u'], ['n', 'd'], ['a', 'n'], ['c'], ['y', ' '], ['a', 'n'], ['d', ' '], ['b'], ['l', 'a'], ['b'], ['l', 'a'], ['b'], ['l', 'a'], ['b'], ['l', 'a'], ['.', '.'], ['.']]

Or should it be able to merge already merged tokens again?

Thanks for your response!

Or should it be able to merge already merged tokens again?

Yes, I’ve tried writing a version with codeunits. A token is represented as an NTuple of code units, and I use your “membership” idea to easily merge tokens together.
It seems to work OK, I’ve not worried about some edge cases such as repeated tokens and merging tokens at the beginning and end of the text.

The performance on a file of a few thousand lines leaves quite a lot to be desired, though. I think the biggest problem right now is that a token is represented as an arbitrarily long tuple, so type inference might have some trouble?

using DataStructures: CircularBuffer, DefaultDict

data = codeunits(replace("This is a test. It contains some test words for redundancy and blablablabla...", " "=>"_"))


membership = zeros(Int, length(data)) # each starting index gives the distance to the end of the current token

# Count how often each adjacent pair of tokens occurs in the current segmentation.
function get_occurences(data, membership)
	occurences = DefaultDict(0)
	
	i = Int(1) # start of the first token
	while i < length(data)
		token1 = data[i:i+membership[i]]
		i += membership[i] + 1 # start of the next token
		token2 = data[i:i+membership[i]]
		pair = (token1..., token2...)
		
		occurences[pair] += 1
	end
	
	return occurences
end

occurences = get_occurences(data, membership)

# One merge step: take the most frequent pair, merge it everywhere in the text by
# updating membership, and incrementally adjust the counts of the neighbouring pairs.
function update!(occurences, membership, data)
	new_token = findmax(occurences)[2]
	println("merging:\t$(new_token)")
	delete!(occurences, new_token)
	
	buffer = CircularBuffer(4)
	i=1
	while !isfull(buffer) # initially fill buffer with first 4 tokens.
		push!(buffer, (i, data[i:i+membership[i]]))
		i += membership[i]+1
	end
	
	# I ignore the edge case of merging the first two, or last two, tokens.
	
	while i < length(data)
		pair = (buffer[2][2]..., buffer[3][2]...)
		if pair == new_token
			membership[buffer[2][1]] += 1 + membership[buffer[3][1]] # merge
			
			# update occurences:
			occurences[(buffer[1][2]..., buffer[2][2]...)] -= 1
			occurences[(buffer[1][2]..., new_token...)] += 1
						
			occurences[(buffer[3][2]..., buffer[4][2]...)] -= 1
			occurences[(new_token..., buffer[4][2]...)] += 1
		end
		i += membership[i]+1
		push!(buffer, (i, data[i:i+membership[i]]))
	end
end

begin
	n_iters = 10 # for a large textfile this might be some thousand times?
	for _ in 1:n_iters; update!(occurences, membership, data); end
end

# Walk through data using membership and collect the tokens as strings.
function get_tokens(data, membership)
	i = 1
	tokenized = []
	while i < length(data)
		push!(tokenized, String([data[i:i+membership[i]]...]))
		i += membership[i]+1
	end
	tokenized
end

get_tokens(data, membership) # returns the tokenized data

unique(get_tokens(data, membership)) # return all the individual tokens that you end up with.

I see, didn’t think about that in my initial try.

It’s certainly straightforward to code, but I don’t know if it is the most efficient approach? Have you looked at the literature to see if there are some clever tricks? Or are you coding freely, too?

For reference: do you have an idea of how fast we’re aiming to go? We could do a little profiling to identify bottlenecks and see if we can make this run fast.

Quite possible, although Tuple{UInt8, Vararg{UInt8}} is a thing. Have you tried looking at @code_warntype?
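For example:

julia> @code_warntype update!(occurences, membership, data)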

Have you looked at the literature to see if there are some clever tricks?

I’ve read some of the literature, but not a lot. Most papers leave these things as an implementation detail, e.g. this one (pdf).
Also I’m not aiming for the absolute fastest code, but your ideas have sped things up already by quite a bit.

For reference: do you have an idea of how fast we’re aiming to go? We could do a little profiling to identify bottlenecks and see if we can make this run fast.

No, I’ve got no real intuition as to how fast this could potentially run. Hugging Face’s BPE training tutorial claims WikiText-103 (about 200 MB) takes only a few seconds! That is surely some orders of magnitude faster than what I’ve written, but then again, I don’t really expect to even get close to those numbers.

As for @code_warntype, I saw a lot of red Any which is what prompted my suspicion.

  #self#::Core.Compiler.Const(update!, false)
  occurences::DefaultDict{Any,Any,Int64}
  membership::Array{Int64,1}
  data::Base.CodeUnits{UInt8,String}
  new_token::Any
  buffer::CircularBuffer{Any}
  i::Int64
  pair::Tuple

Body::Nothing
1 ── %1   = Main.findmax(occurences)::Tuple{Any,Any}
│           (new_token = Base.getindex(%1, 2))
│    %3   = Base.string("merging:\t", new_token)::String
│           Main.println(%3)
│           Main.delete!(occurences, new_token)
│           (buffer = Main.CircularBuffer(4))
└───        (i = 1)
2 ┄─ %8   = Main.isfull(buffer)::Bool
│    %9   = !%8::Bool
└───        goto #4 if not %9
3 ── %11  = buffer::CircularBuffer{Any}
│    %12  = i::Int64
│    %13  = i::Int64
│    %14  = i::Int64
│    %15  = Base.getindex(membership, i)::Int64
│    %16  = (%14 + %15)::Int64
│    %17  = (%13:%16)::UnitRange{Int64}
│    %18  = Base.getindex(data, %17)::Array{UInt8,1}
│    %19  = Core.tuple(%12, %18)::Tuple{Int64,Array{UInt8,1}}
│           Main.push!(%11, %19)
│    %21  = i::Int64
│    %22  = Base.getindex(membership, i)::Int64
│    %23  = (%22 + 1)::Int64
│           (i = %21 + %23)
└───        goto #2
4 ── %26  = Base.getindex(buffer, 1)::Any
│    %27  = Base.getindex(%26, 2)::Any
│    %28  = Base.getindex(buffer, 2)::Any
│    %29  = Base.getindex(%28, 2)::Any
│    %30  = Core._apply_iterate(Base.iterate, Core.tuple, %27, %29)::Tuple
│    %31  = (%30 == new_token)::Any
└───        goto #6 if not %31
5 ── %33  = Base.getindex(membership, 1)::Int64
│    %34  = Base.getindex(buffer, 2)::Any
│    %35  = Base.getindex(%34, 1)::Any
│    %36  = Base.getindex(membership, %35)::Any
│    %37  = (1 + %36)::Any
│    %38  = (%33 + %37)::Any
│           Base.setindex!(membership, %38, 1)
│    %40  = Base.getindex(buffer, 2)::Any
│    %41  = Base.getindex(%40, 2)::Any
│    %42  = Base.getindex(buffer, 3)::Any
│    %43  = Base.getindex(%42, 2)::Any
│    %44  = Core._apply_iterate(Base.iterate, Core.tuple, %41, %43)::Tuple
│    %45  = Base.getindex(occurences, %44)::Any
│    %46  = (%45 - 1)::Any
│    %47  = Core._apply_iterate(Base.iterate, Core.tuple, %41, %43)::Tuple
│           Base.setindex!(occurences, %46, %47)
│    %49  = Base.getindex(buffer, 3)::Any
│    %50  = Base.getindex(%49, 2)::Any
│    %51  = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %50)::Tuple
│    %52  = Base.getindex(occurences, %51)::Any
│    %53  = (%52 + 1)::Any
│    %54  = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %50)::Tuple
└───        Base.setindex!(occurences, %53, %54)
6 ┄─ %56  = i::Int64
│    %57  = Main.length(data)::Int64
│    %58  = (%56 < %57)::Bool
└───        goto #10 if not %58
7 ── %60  = Base.getindex(buffer, 2)::Any
│    %61  = Base.getindex(%60, 2)::Any
│    %62  = Base.getindex(buffer, 3)::Any
│    %63  = Base.getindex(%62, 2)::Any
│           (pair = Core._apply_iterate(Base.iterate, Core.tuple, %61, %63))
│    %65  = (pair == new_token)::Any
└───        goto #9 if not %65
8 ── %67  = Base.getindex(buffer, 2)::Any
│    %68  = Base.getindex(%67, 1)::Any
│    %69  = Base.getindex(membership, %68)::Any
│    %70  = Base.getindex(buffer, 3)::Any
│    %71  = Base.getindex(%70, 1)::Any
│    %72  = Base.getindex(membership, %71)::Any
│    %73  = (1 + %72)::Any
│    %74  = (%69 + %73)::Any
│           Base.setindex!(membership, %74, %68)
│    %76  = Base.getindex(buffer, 1)::Any
│    %77  = Base.getindex(%76, 2)::Any
│    %78  = Base.getindex(buffer, 2)::Any
│    %79  = Base.getindex(%78, 2)::Any
│    %80  = Core._apply_iterate(Base.iterate, Core.tuple, %77, %79)::Tuple
│    %81  = Base.getindex(occurences, %80)::Any
│    %82  = (%81 - 1)::Any
│    %83  = Core._apply_iterate(Base.iterate, Core.tuple, %77, %79)::Tuple
│           Base.setindex!(occurences, %82, %83)
│    %85  = Base.getindex(buffer, 1)::Any
│    %86  = Base.getindex(%85, 2)::Any
│    %87  = Core._apply_iterate(Base.iterate, Core.tuple, %86, new_token)::Tuple
│    %88  = Base.getindex(occurences, %87)::Any
│    %89  = (%88 + 1)::Any
│    %90  = Core._apply_iterate(Base.iterate, Core.tuple, %86, new_token)::Tuple
│           Base.setindex!(occurences, %89, %90)
│    %92  = Base.getindex(buffer, 3)::Any
│    %93  = Base.getindex(%92, 2)::Any
│    %94  = Base.getindex(buffer, 4)::Any
│    %95  = Base.getindex(%94, 2)::Any
│    %96  = Core._apply_iterate(Base.iterate, Core.tuple, %93, %95)::Tuple
│    %97  = Base.getindex(occurences, %96)::Any
│    %98  = (%97 - 1)::Any
│    %99  = Core._apply_iterate(Base.iterate, Core.tuple, %93, %95)::Tuple
│           Base.setindex!(occurences, %98, %99)
│    %101 = Base.getindex(buffer, 4)::Any
│    %102 = Base.getindex(%101, 2)::Any
│    %103 = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %102)::Tuple
│    %104 = Base.getindex(occurences, %103)::Any
│    %105 = (%104 + 1)::Any
│    %106 = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %102)::Tuple
└───        Base.setindex!(occurences, %105, %106)
9 ┄─ %108 = i::Int64
│    %109 = Base.getindex(membership, i)::Int64
│    %110 = (%109 + 1)::Int64
│           (i = %108 + %110)
│    %112 = buffer::CircularBuffer{Any}
│    %113 = i::Int64
│    %114 = i::Int64
│    %115 = i::Int64
│    %116 = Base.getindex(membership, i)::Int64
│    %117 = (%115 + %116)::Int64
│    %118 = (%114:%117)::UnitRange{Int64}
│    %119 = Base.getindex(data, %118)::Array{UInt8,1}
│    %120 = Core.tuple(%113, %119)::Tuple{Int64,Array{UInt8,1}}
│           Main.push!(%112, %120)
└───        goto #6
10 ─        return

You should probably take a look at WordTokenizers.jl, both because it may already be what you need, and because it has the data structures and API for fast parsing implemented, with the TokenBuffer API.

You can check the paper.
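For a taste of the TokenBuffer API, a minimal custom tokeniser looks roughly like this (a sketch from memory of the README; check the docs for the exact primitives):

using WordTokenizers: TokenBuffer, isdone, spaces, character

# A whitespace tokeniser built from TokenBuffer primitives:
# spaces(ts) consumes whitespace and flushes the current token,
# character(ts) appends the current character to the token being built.
function my_tokenise(input)
	ts = TokenBuffer(input)
	while !isdone(ts)
		spaces(ts) || character(ts)
	end
	return ts.tokens
end

my_tokenise("This is a test.") # ["This", "is", "a", "test."]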

It’s >4x faster than spaCy and >6x faster than NLTK.

I suspect HuggingFace’s new tokenisers are faster still, but I haven’t evaluated them.

Though it is mostly about rule-based tokenisers, especially as far as the performance optimisation is concerned, so it might not be so relevant here.

Oh, that’s a lot of Anys for sure. You could start by typing occurences and buffer and see if that fixes inference, and maybe add some function barriers. I’ll take a look at it later if you don’t mind.
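Concretely, something along these lines might already help (just a sketch; whether the abstract tuple key type lets inference do much better is another question):

using DataStructures: CircularBuffer, DefaultDict

const Token = Tuple{Vararg{UInt8}} # a token is a variable-length tuple of bytes

occurences = DefaultDict{Token,Int}(0) # pair counts, keyed by the concatenated byte tuple
buffer = CircularBuffer{Tuple{Int,Vector{UInt8}}}(4) # (start index, token bytes) entries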

Pretty cool; @merckxiaan could probably add this to WordTokenizers.jl if it turns out to be performant.

Very nice basis you have there.

Thanks oxinabox,

I’ve seen WordTokenizers and it looks really interesting. I read the GSoC blog post on SentencePiece and understood that WordTokenizers doesn’t do training of tokenizers, right?
It’s really fascinating that this readable Julia code ends up faster than spaCy, though; I’ll be sure to take a closer look at the implementation.

You could start by typing occurences and buffer and see if that fixes inference.

Thanks!
I’ll try your suggestions.

I’ll take a look at it later if you don’t mind.

I don’t mind at all! I’ve already learned a lot, no pressure whatsoever! 🙂

It’s >4x faster than spaCy and >6x faster than NLTK.

I’m not sure if things have changed, but spaCy v3.0.0 came out yesterday. Good to keep in mind for this or other problems, as it claims state-of-the-art results. The Julia spaCy wrapper hasn’t been updated for 3 years; maybe somebody needs to take over: GitHub - jekbradbury/SpaCy.jl: Julia interface for SpaCy NLP library