Subword Tokenization

Byte Pair Encoding (BPE) is a data compression algorithm that is commonly used in the context of subword tokenization for neural language models. BPE tokenizes text into smaller units, such as subword pieces or characters, to handle out-of-vocabulary words, reduce vocabulary size, and enhance the efficiency of language models.
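
For example, once subwords such as low and est have been learned from a corpus, an unseen word like lowest can still be segmented into low + est rather than being mapped to an unknown token.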

Algorithm

The following describes the steps of BPE in terms of the Expectation-Maximization (EM) algorithm:

  1. Initialization: Given a dictionary consisting of all words and their counts in a corpus, the symbol vocabulary is initialized by tokenizing each word into its most basic subword units, such as characters.

  2. Expectation: With the (updated) symbol vocabulary, it calculates the frequency of every symbol pair within the vocabulary.

  3. Maximization: Given all symbol pairs and their frequencies, it merges the most frequent symbol pair (more generally, the top-k most frequent pairs) in the vocabulary.

  4. Steps 2 and 3 are repeated for a fixed number of merge operations, or until meaningful sets of subwords emerge for all words in the corpus (a small worked example follows this list).
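
To make the expectation and maximization steps concrete, consider a hypothetical corpus containing only low (count 5) and lower (count 2). After initialization, the two words are represented as l o w [EoW] and l o w e r [EoW]. The expectation step counts the pairs (l, o) and (o, w) seven times each, (w, [EoW]) five times, and (w, e), (e, r), (r, [EoW]) twice each; the maximization step then merges a most frequent pair, say (l, o), into the new symbol lo.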

The EM algorithm is a classic method in unsupervised learning. What are the advantages of unsupervised learning over supervised learning, and which tasks align well with unsupervised learning?

Implementation

Let us consider a toy corpus represented as word counts:

import collections
import re

from src.types import WordCount, PairCount

EOW = '[EoW]'  # special symbol marking the end of a word

word_counts = {
    'high': 12,
    'higher': 14,
    'highest': 10,
    'low': 12,
    'lower': 11,
    'lowest': 13
}

First, we create the symbol vocabulary by inserting a space between every pair of adjacent characters and appending the special symbol [EoW] to mark the end of the word:

def initialize(word_counts: WordCount) -> WordCount:
    # Split every word into space-separated characters and append the end-of-word symbol.
    return {' '.join(list(word) + [EOW]): count for word, count in word_counts.items()}
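
As a quick sanity check, applying initialize() to the toy word counts above yields space-separated character sequences, each terminated by [EoW] (the expected output is shown as a comment):

initialize(word_counts)
# {'h i g h [EoW]': 12, 'h i g h e r [EoW]': 14, 'h i g h e s t [EoW]': 10,
#  'l o w [EoW]': 12, 'l o w e r [EoW]': 11, 'l o w e s t [EoW]': 13}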

Next, we count the frequencies of all symbol pairs in the vocabulary:

def expect(vocab: WordCount) -> PairCount:
    # Count how often each pair of adjacent symbols occurs, weighted by word frequency.
    pairs = collections.defaultdict(int)

    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq

    return pairs
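
For illustration, running expect() on the initialized toy vocabulary assigns the highest count, 36, to the pairs shared by all three high* words and by all three low* words; a few representative entries are shown as comments:

pairs = expect(initialize(word_counts))
# ('h', 'i'): 36, ('i', 'g'): 36, ('g', 'h'): 36,
# ('l', 'o'): 36, ('o', 'w'): 36,
# ('e', 'r'): 25, ('r', '[EoW]'): 25, ('e', 's'): 23, ...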

Finally, we update the vocabulary by merging the most frequent symbol pair across all words:

def maximize(vocab: WordCount, pairs: PairCount) -> WordCount:
    # Merge the most frequent pair wherever its two symbols appear adjacently.
    best = max(pairs, key=pairs.get)
    p = re.compile(r'(?<!\S)' + re.escape(' '.join(best)) + r'(?!\S)')
    return {p.sub(''.join(best), word): freq for word, freq in vocab.items()}
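
A single call to maximize() merges exactly one pair. With the counts above, five pairs are tied at 36, and max() returns the first one it encounters, which depends on dictionary insertion order (deterministic in Python 3.7+); in that case the pair ('h', 'i') wins and is merged into the new symbol hi:

vocab = initialize(word_counts)
vocab = maximize(vocab, expect(vocab))
# {'hi g h [EoW]': 12, 'hi g h e r [EoW]': 14, 'hi g h e s t [EoW]': 10,
#  'l o w [EoW]': 12, 'l o w e r [EoW]': 11, 'l o w e s t [EoW]': 13}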

The expect() and maximize() functions can be applied repeatedly until the tokenization becomes reasonable:

def bpe_vocab(word_counts: WordCount, max_iter: int) -> WordCount:
    vocab = initialize(word_counts)

    for i in range(max_iter):
        pairs = expect(vocab)
        vocab = maximize(vocab, pairs)
        # print(vocab)

    return vocab

When you uncomment the print(vocab) statement in bpe_vocab(), you can see how the symbols are merged in each iteration:

bpe_vocab(word_counts, 10)
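
With the toy counts above, ten merges are enough to separate the stems from the suffixes. Assuming ties between equally frequent pairs are broken by dictionary insertion order, as in the implementation above on Python 3.7+, the final vocabulary comes out as:

# {'high [EoW]': 12, 'high er[EoW]': 14, 'high est[EoW]': 10,
#  'low [EoW]': 12, 'low er[EoW]': 11, 'low est[EoW]': 13}

Note that the learned suffixes carry the end-of-word marker, e.g. er[EoW], which distinguishes a word-final er from the same symbol pair occurring inside a word.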

What are the disadvantages of using BPE-based tokenization instead of rule-based tokenization? What are the potential issues with the implementation of BPE above?

References

Source code: src/byte_pair_encoding.py
