cs336-bpe 分词器

前言

这是 cs336 课程 assignment1 的第一个部分,要求实现一个 byte-level byte-pair encoding (BPE) 分词器

Byte-Level BPE 是 BPE 的一个变体,其最关键的区别在于:它的最基础单元是字节(Byte),而不是 Unicode 字符。一般我们说到的 BPE 是基于字符的 BPE,如果指明了是 Byte-Level BPE,那么它的基础单元就是字节。

基于字节的 BPE 的词表只有 256 个字节(0-255),几乎不会出现 oov 的问题。而基于字符的 BPE 的词表大小则取决于训练语料中出现的字符数量,通常会远大于 256,也会出现 oov 的问题。

Unicode Standard

在了解 BPE 之前,我们需要先了解 Unicode Standard。Unicode Standard 是一个字符编码标准,它为每个字符分配了一个唯一的代码点(code point)。代码点是一个整数,通常用十六进制表示,例如 U + 0041 表示字符’A’。

Unicode Encoding

Unicode 编码是将 Unicode 代码点转换为字节序列的过程。常见的 Unicode 编码有 UTF-8、UTF-16 和 UTF-32。UTF-8 是最常用的 Unicode 编码,它使用 1 到 4 个字节表示一个字符。UTF-8 的优点是兼容 ASCII 编码,并且对于常用字符使用较少的字节表示。

cs336 作业中的一段介绍: > While the Unicode standard defines a mapping from characters to code points (integers), it’s impractical to train tokenizers directly on Unicode codepoints, since the vocabulary would be prohibitively large (around 150K items) and sparse (since many characters are quite rare). Instead, we’ll use a Unicode encoding, which converts a Unicode character into a sequence of bytes

To encode a Unicode string into UTF-8, we can use the encode() function in Python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
>>> test_string = "hello! こんにちは!"
>>> utf8_encoded = test_string.encode("utf-8")
>>> print(utf8_encoded)
b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'
>>> print(type(utf8_encoded))
<class 'bytes'>
>>> # Get the byte values for the encoded string (integers from 0 to 255).
>>> list(utf8_encoded)
[104, 101, 108, 108, 111, 33, 32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129,
161, 227, 129, 175, 33]
>>> # One byte does not necessarily correspond to one Unicode character!
>>> print(len(test_string))
13
>>> print(len(utf8_encoded))
23
>>> print(utf8_encoded.decode("utf-8"))
hello! こんにちは!

By converting our Unicode codepoints into a sequence of bytes (e.g., via the UTF-8 encoding), we are essentially taking a sequence of codepoints (integers in the range 0 to 154,997) and transforming it into a sequence of byte values (integers in the range 0 to 255).

具体的编码规则不进行展开了

Subword Tokenization

While byte-level tokenization can alleviate the out-of-vocabulary issues faced by word-level tokenizers, tokenizing text into bytes results in extremely long input sequences. 虽然 byte-level tokenization 可以解决 word-level tokenizers 导致的 oov 的问题,但是这样会导致输入序列非常长。

Subword tokenization is a midpoint between word-level tokenizers and byte-level tokenizers.

我们会把那些经常出现在一起的字节对(byte pair)合并成一个新的子词(subword)。这样,我们就可以在保持较小词表的同时,减少输入序列的长度。

For example, if the byte sequence b’the’ often occurs in our raw text training data, assigning it an entry in the vocabulary would reduce this 3-token sequence to a single token.

BPE Tokenizer Training

BPE 分词器的训练主要分为 3 步

  1. Vocabulary initialization: 词表是一个从 bytestring 到 int 的一一映射。由于我们采用的是 byte-level BPE,所以词表的初始大小就是 256
  2. Pre-tokenization:
    • 理论上,有了初始的词表之后,我们就可以直接遍历所有的字节对,然后把出现频率最高的字节对合并成一个新的子词,加入到词表中。但是这样做效率非常低,因为每一次我们都要遍历整一个语料库。另外直接合并字节可能会产生跨单词边界的 token,或者创建出仅在标点符号上有差异的冗余 token(如 dog 和 dog! 被视为完全不同的 token),这浪费了词汇表空间且不符合语言规律。
    • 加上了预分词之后,我们可以先把语料库分成一个个单词,然后在每个单词内部进行 BPE 合并。这样做的好处是不需要再去原始的字节流中统计频率,而是可以直接统计这些 “预 token” 中各个字节对的频率。另外,预分词也能避免跨单词边界的 token(一般都不跨单词边界合并)。
    • 例子:如果单词” text” 在语料中出现了 10 次,那么当我们统计字节’t’和’e’在” text” 内部相邻的频率时,我们可以直接为这个词对 (t, e) 的频率加上 10,而不是在原始语料中扫描 10 次。这大大提升了统计效率。
  3. Compute BPE merges: 现在已经把输入变成了 pre-tokns 并且每一个 pre-token 都是一个 utf-8 编码的 bytestring。接下来我们就可以统计这些 pre-token 中各个字节对的频率,然后把出现频率最高的字节对合并成一个新的子词,加入到词表中。并且接下来要把所有包含这个字节对的 pre-token 都进行更新(合并这个字节对)。重复这个过程,直到词表达到指定的大小。作业中要求如果有多个字节对频率相同,那么就选择字典序最大的那个字节对进行合并。

另外需要注意 special tokens 的处理。special tokens 是一些特殊的 token,比如 [PAD]、[UNK]、[CLS]、[SEP]、<|endoftext|> 等。这些 token 在训练 BPE 分词器时需要被加入到词表中,并且不能被修改。

实现细节

Pre-tokenization

首先是预分词部分,这一部分是可以进行并行的。实验要求中说 pre-tokenize 是主要的性能瓶颈,而这一部分瓶颈是可以通过并行解决,思路就是把语料库分成多个 chunk,然后每个 chunk 交给一个线程进行预分词,最后把所有线程的结果合并起来。

然后注意这一部分是不能包含 special tokens 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def pre_tokenize(text: str, special_tokens: list[str] | None = None, preserve_special_tokens: bool = True) -> list[bytes]:
docs = split_by_special_tokens(text, special_tokens)
# print("\ndocs: ", docs) # ['Héllò hôw ', '<|endoftext|>', '', '<|endoftext|>', ' are ü? 🙃', '<|endoftext|>', '']
PAT = PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
token_bytes_list = []
for doc in docs:
if special_tokens and doc in special_tokens:
if preserve_special_tokens:
special_token_bytes = doc.encode("utf-8")
token_bytes_list.extend([special_token_bytes])
else:
tokens = re.findall(PAT, doc) # ['the', ' cat', ' ate']
token_bytes = [t.encode("utf-8") for t in tokens] # [b'the', b' cat', b' ate'].
token_bytes_list.extend(token_bytes)
# print("token_bytes_list:", token_bytes_list)
return token_bytes_list

BPE Merges

这一部分是整个 BPE 训练的核心部分,基本的算法也很直观。首先构造一个 word_cnt 字典,key 是 pre-token 的 bytestring,value 是这个 pre-token 在语料中出现的次数。然后统计所有 pre-token 中相邻字节对的频率,选择频率最高的字节对进行合并,更新 word_cnt,更新 merges 列表,更新 vocab。重复这个过程,直到词表达到指定的大小。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
merges = []
for i in range(num_merges):
pair_cnt = count_pair(word_cnt)
max_pair = MAX(pair_cnt)
merges.append(max_pair)
vocab[base_vocab_size + i] = max_pair # 直接赋值

new_word_cnt = {}
for word_bytes, cnt in word_cnt.items():
j = 0
new_word_bytes = []
while j < len(word_bytes):
if j < len(word_bytes) - 1 and (word_bytes[j], word_bytes[j+1]) == max_pair:
new_word_bytes.append(max_pair[0] + max_pair[1])
j += 2
else:
new_word_bytes.append(word_bytes[j])
j += 1
new_word_bytes = tuple(new_word_bytes)
new_word_cnt[new_word_bytes] = new_word_cnt.get(new_word_bytes, 0) + cnt
word_cnt = new_word_cnt # 更新word_cnt

return vocab, merges

当然,上面的代码是没法过测试的,因为时间复杂度很高,有如下几个地方可以进行优化 1. 找到频率最高的字节对时,可以使用堆(heap)来优化 2. 更新 word_cnt 时,可以只更新包含 max_pair 的 pre-token,而不是所有的 pre-token,这里可以使用倒排索引来记录包含每个字节对的 pre-token 列表

总之,优化的思路就是使用更加高效的数据结构以及减少不必要的计算,完整的优化后代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def train_bpe(
input_path: str,
vocab_size: int,
special_tokens: list[str]
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
with open(input_path, "rb") as f:
num_processes = 8
boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

# The following is a serial implementation, but you can parallelize this
# by sending each start/end pair to a set of processes.
chunks = []
for start, end in zip(boundaries[:-1], boundaries[1:]):
f.seek(start)
chunk = f.read(end - start).decode("utf-8", errors="ignore")
# Run pre-tokenization on your chunk and store the counts for each pre-token
chunks.append(chunk)

args = [(chunk, special_tokens) for chunk in chunks]

with multiprocessing.Pool(processes=num_processes) as pool:
results = pool.starmap(process_chunk, args)

word_cnt = Counter()
for result in results:
word_cnt.update(result)

vocab = get_basic_vocab(special_tokens)
base_vocab_size = len(vocab)
num_merges = vocab_size - base_vocab_size

pair_cnt, pair2word_bytes = count_pair(word_cnt)

merges = []
for i in range(num_merges):

if i % 100 == 0: # 每100轮重建一次堆,避免堆变得太大
heap = MaxHeap()
for pair, cnt in pair_cnt.items():
heap.push((cnt, pair))

# 懒惰删除
while True:
cnt, pair = heap.pop()
# print(cnt, pair)
if pair in pair_cnt and pair_cnt[pair] == cnt:
max_pair = pair
break
# 否则丢弃,继续弹出

merges.append(max_pair)
vocab[base_vocab_size + i] = max_pair[0] + max_pair[1]

# 更新word_cnt
# 只需要改动那些max_pair合并影响到的
affected_word_bytes = pair2word_bytes[max_pair] # set
pair2word_bytes.pop(max_pair)
affected_pairs = set()
for word_bytes in affected_word_bytes:
cnt = word_cnt[word_bytes]
word_cnt.pop(word_bytes)

for pair in zip(word_bytes[:-1], word_bytes[1:]):
pair_cnt[pair] -= cnt
if pair_cnt[pair] == 0:
del pair_cnt[pair]
pair2word_bytes[pair].discard(word_bytes)
if not pair2word_bytes[pair]:
del pair2word_bytes[pair]

affected_pairs.add(pair) # 因为有些pair及时变少了,多轮以后可能也是最大的,所以这里必须要追踪

j = 0
new_word_bytes = []
while j < len(word_bytes):
if j < len(word_bytes) - 1 and (word_bytes[j], word_bytes[j+1]) == max_pair: # 遇到要合并的情况
new_word_bytes.append(max_pair[0] + max_pair[1])
j += 2
else:
new_word_bytes.append(word_bytes[j])
j += 1
new_word_bytes = tuple(new_word_bytes)
word_cnt[new_word_bytes] += cnt


for i, pair in enumerate(zip(new_word_bytes[:-1], new_word_bytes[1:])):
pair_cnt[pair] += cnt
pair2word_bytes[pair].add(new_word_bytes)
affected_pairs.add(pair)

for pair in affected_pairs:
if pair in pair_cnt:
heap.push((pair_cnt[pair], pair))

return vocab, merges

这段代码跑完第一个测试点只需要 0.3s,远超过 1.5s 的要求

BPE Tokenizer Encoding and Decoding

BPE 分词器的编码和解码相对简单,主要是根据训练好的词表和 merges 进行操作。

Encoding

  1. Pre-tokenization: 这一部分和训练时的预分词是一样的,都是把输入文本分成一个个 pre-token。
  2. Apply the merges: 对每个 pre-token,先拆成一个个的字节,然后根据 merges 列表的顺序进行合并,直到不能再合并为止。最后把所有的 token 的 id 拼接起来,作为最终的编码结果。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    # 必须按照词表的顺序
    for merge in self.merges:
    new_token_bytes = []
    j = 0

    # print(merge)
    while j < len(token_bytes):
    if j < len(token_bytes) - 1 and (token_bytes[j], token_bytes[j+1]) == merge:
    new_token_bytes.append(token_bytes[j] + token_bytes[j+1])
    j += 2
    else:
    new_token_bytes.append(token_bytes[j])
    j += 1

    # print(new_token_bytes)
    token_bytes = new_token_bytes
    由此可见,每一个 token 都要把 merge 表从头到尾遍历一遍,时间复杂度是 O (n * m)

另外需要注意,special tokens 需要被保留

Decoding

解码相对简单,就是把 token id 转换成 bytes(通过 vocab),然后拼接成一个 bytestring

完整的 BPE Tokenizer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class BPETokenizer:
def __init__(
self,
vocab: dict[int, bytes],
merges: list[tuple[bytes, bytes]],
special_tokens: list[str] | None = None
):
self.vocab = vocab
self.merges = merges
self.special_tokens = special_tokens

def from_files(
cls,
vocab_filepath: str,
merges_filepath: str,
special_tokens: list[str] | None = None
):
with open(vocab_filepath) as vocab_f:
vocab = json.load(vocab_f)
merges = []
with open(merges_filepath) as f:
for line in f:
cleaned_line = line.rstrip()
if cleaned_line and len(cleaned_line.split(" ")) == 2:
merges.append(tuple(cleaned_line.split(" ")))

if special_tokens:
for special_token in special_tokens:
byte_encoded_special_token = special_token.encode("utf-8")
if byte_encoded_special_token not in set(vocab.values()):
vocab[len(vocab)] = byte_encoded_special_token

vocab = {
vocab_index: bytes([c for c in vocab_item])
for vocab_index, vocab_item in vocab.items()
}

merges = [
(
bytes([token for token in merge_token_1]),
bytes([token for token in merge_token_2]),
)
for merge_token_1, merge_token_2 in merges
]
tokenizer = BPETokenizer(vocab, merges, special_tokens)
return tokenizer


def encode(self, text: str) -> list[int]:
vocab_reversed = {v: k for k, v in self.vocab.items()}
token_bytes_list = pre_tokenize(text, self.special_tokens) # list[bytes]

byte_special_tokens = []
if self.special_tokens:
byte_special_tokens = [special_token.encode('utf-8') for special_token in self.special_tokens]
# print("byte_special_tokens:", byte_special_tokens)

# print((b'He', b'llo') in merges)
new_token_bytes_list = []
for token_bytes in token_bytes_list: # bytes e.g. b'the'
# print("initial token_bytes:", token_bytes)

if token_bytes in byte_special_tokens:
new_token_bytes_list.append(token_bytes)
# print("special: ", token_bytes)
continue
# 转换为list[bytes]
token_bytes = [bytes([byte]) for byte in token_bytes]
# print("token_bytes:", token_bytes)
# print(token_bytes[0])


# 必须按照词表的顺序
for merge in self.merges:
new_token_bytes = []
j = 0

# print(merge)
while j < len(token_bytes):
if j < len(token_bytes) - 1 and (token_bytes[j], token_bytes[j+1]) == merge:
new_token_bytes.append(token_bytes[j] + token_bytes[j+1])
j += 2
else:
new_token_bytes.append(token_bytes[j])
j += 1

# print(new_token_bytes)
token_bytes = new_token_bytes
# print(len(new_token_bytes))
new_token_bytes_list.extend(new_token_bytes)
# print("new_token_bytes_list after merged: ", new_token_bytes_list)

new_token_ids_list = [vocab_reversed[i] for i in new_token_bytes_list]
# print(new_token_ids_list)
return new_token_ids_list

def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
"""
Given an iterable of
strings (e.g., a Python file handle), return a generator that lazily yields token IDs. This is
required for memory-efficient tokenization of large files that we cannot directly load into memory
"""
for line in iterable:
for idx in self.encode(line):
yield idx


def decode(self, ids: list[int]) -> str:

# print("ids:", ids)
result = []
# print(type(list(self.vocab.keys())[0])) # int
for id in ids:
# print(self.vocab[id].decode("utf-8"))
# print(type(self.vocab.get(id, None)))
result.extend(self.vocab.get(id, None))
# print("decode result:", result)
result = bytes(result).decode("utf-8", errors="replace")
# print("decode result:", result)
return result
正在加载今日诗词....
欢迎关注我的其它发布渠道