Have you looked at the literature to see if there are some clever tricks?
I’ve read some of the literature, but not a lot. Most papers leave these things as implementation details.
e.g. this one (pdf):
.
Also, I’m not aiming for the absolute fastest code, but your ideas have already sped things up quite a bit.
For reference: do you know how fast we’re aiming for? We could do a little profiling to identify bottlenecks and see if we can make this run faster.
No, I’ve got no real intuition for how fast this could potentially run. Hugging Face’s BPE training tutorial claims that WikiText-103 (~200 MB) takes only a few seconds! That is surely some orders of magnitude faster than what I’ve written, but again, I don’t really expect to get anywhere close to those numbers.
As for `@code_warntype`, I saw a lot of red `Any`, which is what prompted my suspicion.
#self#::Core.Compiler.Const(update!, false)
occurences::DefaultDict{Any,Any,Int64}
membership::Array{Int64,1}
data::Base.CodeUnits{UInt8,String}
new_token::Any
buffer::CircularBuffer{Any}
i::Int64
pair::Tuple
Body::Nothing
1 ── %1 = Main.findmax(occurences)::Tuple{Any,Any}
│ (new_token = Base.getindex(%1, 2))
│ %3 = Base.string("merging:\t", new_token)::String
│ Main.println(%3)
│ Main.delete!(occurences, new_token)
│ (buffer = Main.CircularBuffer(4))
└─── (i = 1)
2 ┄─ %8 = Main.isfull(buffer)::Bool
│ %9 = !%8::Bool
└─── goto #4 if not %9
3 ── %11 = buffer::CircularBuffer{Any}
│ %12 = i::Int64
│ %13 = i::Int64
│ %14 = i::Int64
│ %15 = Base.getindex(membership, i)::Int64
│ %16 = (%14 + %15)::Int64
│ %17 = (%13:%16)::UnitRange{Int64}
│ %18 = Base.getindex(data, %17)::Array{UInt8,1}
│ %19 = Core.tuple(%12, %18)::Tuple{Int64,Array{UInt8,1}}
│ Main.push!(%11, %19)
│ %21 = i::Int64
│ %22 = Base.getindex(membership, i)::Int64
│ %23 = (%22 + 1)::Int64
│ (i = %21 + %23)
└─── goto #2
4 ── %26 = Base.getindex(buffer, 1)::Any
│ %27 = Base.getindex(%26, 2)::Any
│ %28 = Base.getindex(buffer, 2)::Any
│ %29 = Base.getindex(%28, 2)::Any
│ %30 = Core._apply_iterate(Base.iterate, Core.tuple, %27, %29)::Tuple
│ %31 = (%30 == new_token)::Any
└─── goto #6 if not %31
5 ── %33 = Base.getindex(membership, 1)::Int64
│ %34 = Base.getindex(buffer, 2)::Any
│ %35 = Base.getindex(%34, 1)::Any
│ %36 = Base.getindex(membership, %35)::Any
│ %37 = (1 + %36)::Any
│ %38 = (%33 + %37)::Any
│ Base.setindex!(membership, %38, 1)
│ %40 = Base.getindex(buffer, 2)::Any
│ %41 = Base.getindex(%40, 2)::Any
│ %42 = Base.getindex(buffer, 3)::Any
│ %43 = Base.getindex(%42, 2)::Any
│ %44 = Core._apply_iterate(Base.iterate, Core.tuple, %41, %43)::Tuple
│ %45 = Base.getindex(occurences, %44)::Any
│ %46 = (%45 - 1)::Any
│ %47 = Core._apply_iterate(Base.iterate, Core.tuple, %41, %43)::Tuple
│ Base.setindex!(occurences, %46, %47)
│ %49 = Base.getindex(buffer, 3)::Any
│ %50 = Base.getindex(%49, 2)::Any
│ %51 = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %50)::Tuple
│ %52 = Base.getindex(occurences, %51)::Any
│ %53 = (%52 + 1)::Any
│ %54 = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %50)::Tuple
└─── Base.setindex!(occurences, %53, %54)
6 ┄─ %56 = i::Int64
│ %57 = Main.length(data)::Int64
│ %58 = (%56 < %57)::Bool
└─── goto #10 if not %58
7 ── %60 = Base.getindex(buffer, 2)::Any
│ %61 = Base.getindex(%60, 2)::Any
│ %62 = Base.getindex(buffer, 3)::Any
│ %63 = Base.getindex(%62, 2)::Any
│ (pair = Core._apply_iterate(Base.iterate, Core.tuple, %61, %63))
│ %65 = (pair == new_token)::Any
└─── goto #9 if not %65
8 ── %67 = Base.getindex(buffer, 2)::Any
│ %68 = Base.getindex(%67, 1)::Any
│ %69 = Base.getindex(membership, %68)::Any
│ %70 = Base.getindex(buffer, 3)::Any
│ %71 = Base.getindex(%70, 1)::Any
│ %72 = Base.getindex(membership, %71)::Any
│ %73 = (1 + %72)::Any
│ %74 = (%69 + %73)::Any
│ Base.setindex!(membership, %74, %68)
│ %76 = Base.getindex(buffer, 1)::Any
│ %77 = Base.getindex(%76, 2)::Any
│ %78 = Base.getindex(buffer, 2)::Any
│ %79 = Base.getindex(%78, 2)::Any
│ %80 = Core._apply_iterate(Base.iterate, Core.tuple, %77, %79)::Tuple
│ %81 = Base.getindex(occurences, %80)::Any
│ %82 = (%81 - 1)::Any
│ %83 = Core._apply_iterate(Base.iterate, Core.tuple, %77, %79)::Tuple
│ Base.setindex!(occurences, %82, %83)
│ %85 = Base.getindex(buffer, 1)::Any
│ %86 = Base.getindex(%85, 2)::Any
│ %87 = Core._apply_iterate(Base.iterate, Core.tuple, %86, new_token)::Tuple
│ %88 = Base.getindex(occurences, %87)::Any
│ %89 = (%88 + 1)::Any
│ %90 = Core._apply_iterate(Base.iterate, Core.tuple, %86, new_token)::Tuple
│ Base.setindex!(occurences, %89, %90)
│ %92 = Base.getindex(buffer, 3)::Any
│ %93 = Base.getindex(%92, 2)::Any
│ %94 = Base.getindex(buffer, 4)::Any
│ %95 = Base.getindex(%94, 2)::Any
│ %96 = Core._apply_iterate(Base.iterate, Core.tuple, %93, %95)::Tuple
│ %97 = Base.getindex(occurences, %96)::Any
│ %98 = (%97 - 1)::Any
│ %99 = Core._apply_iterate(Base.iterate, Core.tuple, %93, %95)::Tuple
│ Base.setindex!(occurences, %98, %99)
│ %101 = Base.getindex(buffer, 4)::Any
│ %102 = Base.getindex(%101, 2)::Any
│ %103 = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %102)::Tuple
│ %104 = Base.getindex(occurences, %103)::Any
│ %105 = (%104 + 1)::Any
│ %106 = Core._apply_iterate(Base.iterate, Core.tuple, new_token, %102)::Tuple
└─── Base.setindex!(occurences, %105, %106)
9 ┄─ %108 = i::Int64
│ %109 = Base.getindex(membership, i)::Int64
│ %110 = (%109 + 1)::Int64
│ (i = %108 + %110)
│ %112 = buffer::CircularBuffer{Any}
│ %113 = i::Int64
│ %114 = i::Int64
│ %115 = i::Int64
│ %116 = Base.getindex(membership, i)::Int64
│ %117 = (%115 + %116)::Int64
│ %118 = (%114:%117)::UnitRange{Int64}
│ %119 = Base.getindex(data, %118)::Array{UInt8,1}
│ %120 = Core.tuple(%113, %119)::Tuple{Int64,Array{UInt8,1}}
│ Main.push!(%112, %120)
└─── goto #6
10 ─ return```