Byte Pair Encoding (BPE) is a data compression algorithm that is commonly adapted for subword tokenization in neural language models. BPE tokenizes text into smaller units, such as subword pieces or characters, to handle out-of-vocabulary words, reduce the vocabulary size, and improve the efficiency of language models.
Algorithm
The following describes the steps of BPE in terms of the Expectation-Maximization (EM) algorithm:
Initialization: Given a dictionary consisting of all words and their counts in a corpus, the symbol vocabulary is initialized by tokenizing each word into its most basic subword units, such as characters.
Expectation: Given the (updated) symbol vocabulary, calculate the frequency of every adjacent symbol pair within the vocabulary.
Maximization: Given all symbol pairs and their frequencies, merge the top-k most frequent symbol pairs in the vocabulary (the implementation below uses k = 1).
The expectation and maximization steps are repeated until meaningful sets of subwords are found for all words in the corpus.
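To make the expectation and maximization steps concrete before the full implementation in the next section, the following is a minimal hand-traced sketch; the two-word corpus, its counts, and the [EoW] end-of-word symbol are assumed for illustration only.

import collections

# Illustrative corpus: 'low' appears twice and 'lower' once, with each word already
# split into characters plus an end-of-word symbol.
vocab = {'l o w [EoW]': 2, 'l o w e r [EoW]': 1}

# Expectation: count every adjacent symbol pair, weighted by the word's frequency.
pairs = collections.Counter()
for word, freq in vocab.items():
    symbols = word.split()
    for left, right in zip(symbols, symbols[1:]):
        pairs[left, right] += freq
# ('l', 'o') and ('o', 'w') each occur 3 times; ('w', '[EoW]') occurs twice.

# Maximization (with k = 1): merge the most frequent pair in every word.
best = max(pairs, key=pairs.get)
vocab = {word.replace(' '.join(best), ''.join(best)): freq for word, freq in vocab.items()}
# If ('l', 'o') is selected, the result is {'lo w [EoW]': 2, 'lo w e r [EoW]': 1}.
# Note that plain str.replace can merge across symbol boundaries in general; the
# implementation below avoids this with a regular expression that matches whole symbols.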
Q4: The EM algorithm stands as a classic method in unsupervised learning. What are the advantages of unsupervised learning over supervised learning, and which tasks align well with unsupervised learning?
Implementation
Let us consider a toy vocabulary:
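The vocabulary literal itself is not reproduced in this copy, so the setup below is a minimal sketch: the word counts are read off the first line of the iteration output at the end of this section, while the imports, the type aliases WordCount and PairCount, and the EOW constant are assumptions inferred from the function signatures in the listing that follows.

import collections
import re

# Assumed type aliases and end-of-word symbol used by the functions below.
WordCount = dict[str, int]              # word (or its space-delimited symbols) -> corpus count
PairCount = dict[tuple[str, str], int]  # adjacent symbol pair -> frequency
EOW = '[EoW]'

# Word counts reconstructed from the iteration output shown at the end of this section.
word_counts = {'high': 12, 'higher': 14, 'highest': 10, 'low': 12, 'lower': 11, 'lowest': 13}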
First, we create the symbol vocabulary by inserting a space between every pair of adjacent characters and appending a special symbol [EoW] to mark the end of the word; this is done by initialize() in the listing below.
Next, we count the frequencies of all adjacent symbol pairs in the vocabulary with expect().
Finally, we update the vocabulary by merging the most frequent symbol pair across all words with maximize().
The expect() and maximize() steps can be repeated for multiple iterations until the tokenization becomes reasonable, as in bpe_vocab().
When you uncomment the print(vocab) statement in bpe_vocab(), you can see how the symbols are merged in each iteration; the ten iterations are shown after the code listing below.
Q5: What are the disadvantages of using BPE-based tokenization instead of other tokenization schemes? What are the potential issues with the implementation of BPE below?
def initialize(word_counts: WordCount) -> WordCount:
    # Split every word into its characters separated by spaces and append the
    # end-of-word symbol, e.g., 'low' -> 'l o w [EoW]'.
    return {' '.join(list(word) + [EOW]): count for word, count in word_counts.items()}

def expect(vocab: WordCount) -> PairCount:
    # Count the frequency of every adjacent symbol pair, weighted by the word count.
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def maximize(vocab: WordCount, pairs: PairCount) -> WordCount:
    # Merge the most frequent pair in every word; the lookarounds ensure that
    # only whole space-delimited symbols are merged.
    best = max(pairs, key=pairs.get)
    p = re.compile(r'(?<!\S)' + re.escape(' '.join(best)) + r'(?!\S)')
    return {p.sub(''.join(best), word): freq for word, freq in vocab.items()}

def bpe_vocab(word_counts: WordCount, max_iter: int) -> WordCount:
    # Alternate the expectation and maximization steps for max_iter iterations.
    vocab = initialize(word_counts)
    for i in range(max_iter):
        pairs = expect(vocab)
        vocab = maximize(vocab, pairs)
        # print(vocab)
    return vocab
bpe_vocab(word_counts, 10)
{'hi g h [EoW]': 12, 'hi g h e r [EoW]': 14, 'hi g h e s t [EoW]': 10, 'l o w [EoW]': 12, 'l o w e r [EoW]': 11, 'l o w e s t [EoW]': 13}
{'hig h [EoW]': 12, 'hig h e r [EoW]': 14, 'hig h e s t [EoW]': 10, 'l o w [EoW]': 12, 'l o w e r [EoW]': 11, 'l o w e s t [EoW]': 13}
{'high [EoW]': 12, 'high e r [EoW]': 14, 'high e s t [EoW]': 10, 'l o w [EoW]': 12, 'l o w e r [EoW]': 11, 'l o w e s t [EoW]': 13}
{'high [EoW]': 12, 'high e r [EoW]': 14, 'high e s t [EoW]': 10, 'lo w [EoW]': 12, 'lo w e r [EoW]': 11, 'lo w e s t [EoW]': 13}
{'high [EoW]': 12, 'high e r [EoW]': 14, 'high e s t [EoW]': 10, 'low [EoW]': 12, 'low e r [EoW]': 11, 'low e s t [EoW]': 13}
{'high [EoW]': 12, 'high er [EoW]': 14, 'high e s t [EoW]': 10, 'low [EoW]': 12, 'low er [EoW]': 11, 'low e s t [EoW]': 13}
{'high [EoW]': 12, 'high er[EoW]': 14, 'high e s t [EoW]': 10, 'low [EoW]': 12, 'low er[EoW]': 11, 'low e s t [EoW]': 13}
{'high [EoW]': 12, 'high er[EoW]': 14, 'high es t [EoW]': 10, 'low [EoW]': 12, 'low er[EoW]': 11, 'low es t [EoW]': 13}
{'high [EoW]': 12, 'high er[EoW]': 14, 'high est [EoW]': 10, 'low [EoW]': 12, 'low er[EoW]': 11, 'low est [EoW]': 13}
{'high [EoW]': 12, 'high er[EoW]': 14, 'high est[EoW]': 10, 'low [EoW]': 12, 'low er[EoW]': 11, 'low est[EoW]': 13}
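As a small follow-up that is not part of the listing above, one way to inspect the subword symbols learned after these ten merges is to collect the distinct space-delimited symbols from the returned vocabulary; the expected result follows directly from the final iteration shown above.

# Collect the distinct subword symbols remaining after 10 merges.
vocab = bpe_vocab(word_counts, 10)
symbols = sorted({s for word in vocab for s in word.split()})
print(symbols)
# Expected, given the final vocabulary above:
# ['[EoW]', 'er[EoW]', 'est[EoW]', 'high', 'low']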