bpe.go: 203 changes (202 additions, 1 deletion)
@@ -2,6 +2,7 @@ package bpe

import (
"bufio"
"container/heap"
"encoding/binary"
"errors"
"io"
@@ -14,6 +15,9 @@ import (
// TokenID is a numerical identifier of the subword token
type TokenID uint32

// TokenIDPair is a concatenation of two TokenIDs, used as the key type in the rule2id map.
type TokenIDPair uint64

// EncodedString is a sequence of subword token identifiers
type EncodedString []TokenID

@@ -24,6 +28,13 @@ const (
eosToken = "<EOS>"
)

// EncodingConfig is a configuration for encoding strings
type EncodingConfig struct {
bos bool
eos bool
reverse bool
}
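
// Illustrative sketch (not part of this diff): the fields are unexported, so an
// EncodingConfig is built with a composite literal from inside the bpe package,
// e.g. in tests. This hypothetical value requests BOS and EOS tokens and leaves
// the output order unreversed:
//
//	config := EncodingConfig{bos: true, eos: true, reverse: false}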

type rule struct {
left TokenID
right TokenID
@@ -43,6 +54,7 @@ type Model struct {
char2id map[rune]TokenID
id2char map[TokenID]rune
rules []rule
rule2id map[TokenIDPair]int
recipe map[TokenID]EncodedString
revRecipe map[string]TokenID
specialTokens specialTokens
@@ -54,13 +66,18 @@ func newModel(nRules int) *Model {
make(map[rune]TokenID),
make(map[TokenID]rune),
make([]rule, nRules),
make(map[TokenIDPair]int),
make(map[TokenID]EncodedString),
make(map[string]TokenID),
specialTokens{-1, -1, -1, -1},
0,
}
}

func newTokenIDPair(left, right TokenID) TokenIDPair {
return (TokenIDPair(left) << 32) + TokenIDPair(right)
}
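
// Illustrative sketch (not part of this diff): the pair key packs the left id
// into the high 32 bits and the right id into the low 32 bits of the uint64,
// so both halves can be recovered with a shift and a mask. The helper name
// splitTokenIDPair is hypothetical.
func splitTokenIDPair(pair TokenIDPair) (left, right TokenID) {
	return TokenID(pair >> 32), TokenID(pair & 0xFFFFFFFF)
}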

// DecodeToken converts a sequence of char ids into the string of the
// corresponding chars
func DecodeToken(token EncodedString, id2char map[TokenID]rune) (string, error) {
@@ -167,7 +184,6 @@ func ReadModel(reader io.Reader) (*Model, error) {
if err != nil {
return model, err
}
if _, ok := model.recipe[rule.left]; !ok {
logrus.Errorf("%d: token id not described before", rule.left)
return model, errors.New("token id is impossible")
@@ -176,6 +192,8 @@
logrus.Errorf("%d: token id not described before", rule.right)
return model, errors.New("token id is impossible")
}
model.rules[i] = rule
model.rule2id[newTokenIDPair(rule.left, rule.right)] = i
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
if err != nil {
@@ -194,6 +212,10 @@ func ReadModel(reader io.Reader) (*Model, error) {
return model, err
}
model.specialTokens = specials
model.revRecipe[bosToken] = TokenID(specials.bos)
model.revRecipe[eosToken] = TokenID(specials.eos)
model.revRecipe[unkToken] = TokenID(specials.unk)
model.revRecipe[padToken] = TokenID(specials.pad)
return model, err
}

@@ -287,3 +309,182 @@ func (m Model) DecodeFromStream(reader io.Reader) ([]string, error) {
}
return sentences, nil
}

type encodingToken struct {
id TokenID
prev int
next int
}
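
// Illustrative sketch (not part of this diff): encodedWord in EncodeSentence
// below acts as a doubly linked list stored in a slice; prev and next hold
// slice indices and -1 marks both ends. A hypothetical walk over the surviving
// tokens, assuming the head sits at index 0:
//
//	var out EncodedString
//	for pos := 0; pos != -1; pos = encodedWord[pos].next {
//		out = append(out, encodedWord[pos].id)
//	}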

type mergeEvent struct {
priority int
pos int
}

type mergeQueue []*mergeEvent

func (mq mergeQueue) Len() int { return len(mq) }

func (mq mergeQueue) Less(i, j int) bool {
return mq[i].priority < mq[j].priority ||
(mq[i].priority == mq[j].priority && mq[i].pos < mq[j].pos)
}

func (mq mergeQueue) Swap(i, j int) {
mq[i], mq[j] = mq[j], mq[i]
}

func (mq *mergeQueue) Push(x interface{}) {
*mq = append(*mq, x.(*mergeEvent))
}

func (mq *mergeQueue) Pop() interface{} {
old := *mq
n := len(old)
item := old[n-1]
old[n-1] = nil // avoid memory leak
*mq = old[0 : n-1]
return item
}
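
// Illustrative sketch (not part of this diff): mergeQueue satisfies
// container/heap's Interface, so merge events pop in rule-priority order, with
// ties broken by the leftmost position. The function name is hypothetical.
func exampleMergeOrder() {
	var pending mergeQueue
	heap.Push(&pending, &mergeEvent{priority: 2, pos: 5})
	heap.Push(&pending, &mergeEvent{priority: 0, pos: 1})
	first := heap.Pop(&pending).(*mergeEvent)
	_ = first // first.priority == 0: the lowest rule index merges first
}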

// EncodeSentence takes a string of space-separated words and tokenizes each word
// according to the BPE rules. Through encodingConfig one can state whether to add
// BOS and EOS tokens (beginning and end of sentence) and whether to reverse the
// output sequences. EncodeSentence returns the numerical encoding of the sentence.
func (m Model) EncodeSentence(sentence string, encodingConfig EncodingConfig,
) (EncodedString, error) {
var encodedSentence EncodedString

if encodingConfig.bos {
if m.specialTokens.bos == -1 {
logrus.Error("Cannot use bos - model was trained without it")
return encodedSentence, errors.New("model was trained without bos")
}
encodedSentence = append(encodedSentence, TokenID(m.specialTokens.bos))
}
for _, word := range strings.Fields(sentence) {
if len(word) == 0 {
continue
}
var encodedWord = []encodingToken{{m.spaceID, -1, 1}}
var pendingMerges mergeQueue
// Check whether two consecutive tokens can be merged and, if so, add a merge
// suggestion to the priority queue
pushIfRuleExists := func(leftPos int) {
rightPos := encodedWord[leftPos].next
ruleCandidate := newTokenIDPair(encodedWord[leftPos].id, encodedWord[rightPos].id)
if priority, ok := m.rule2id[ruleCandidate]; ok {
heap.Push(&pendingMerges, &mergeEvent{priority, leftPos})
}
}
// Build linked list corresponding to the word's split on known chars and unknown tokens
unknownToken := false
for _, char := range word {
if charID, ok := m.char2id[char]; ok {
if unknownToken {
encodedWord = append(encodedWord,
encodingToken{TokenID(m.specialTokens.unk), len(encodedWord) - 1,
len(encodedWord) + 1})
unknownToken = false
}
encodedWord = append(encodedWord,
encodingToken{charID, len(encodedWord) - 1, len(encodedWord) + 1})
pushIfRuleExists(len(encodedWord) - 2)
} else {
unknownToken = true
}
}
if unknownToken {
encodedWord = append(encodedWord,
encodingToken{TokenID(m.specialTokens.unk), len(encodedWord) - 1,
len(encodedWord) + 1})
}
encodedWord[len(encodedWord)-1].next = -1
// Perform merges of subword tokens in the word according to the BPE model rules
for len(pendingMerges) > 0 {
event := heap.Pop(&pendingMerges).(*mergeEvent)
proposedRule := m.rules[event.priority]
leftPos := event.pos
leftToken := encodedWord[leftPos]
rightPos := leftToken.next
if rightPos == -1 {
continue
}
rightToken := encodedWord[rightPos]
// Check that the tokens suggested for the merge have not changed
if proposedRule.left != leftToken.id || proposedRule.right != rightToken.id {
continue
}
// Create a token as a merge of the left and the right ones
leftToken.next = rightToken.next
leftToken.id = proposedRule.result
// Put the merged token in place of the left token
encodedWord[leftPos] = leftToken
// Put an 'empty' token in place of the right token
encodedWord[rightPos] = encodingToken{0, -1, -1}
// Add suggestions for merges for the new merged token
if rightToken.next != -1 {
encodedWord[rightToken.next].prev = leftPos
pushIfRuleExists(leftPos)
}
if leftToken.prev != -1 {
pushIfRuleExists(leftToken.prev)
}
}
// Append the remaining tokens, in order, to the result for the whole sentence
for pos := 0; pos > -1; {
encodedSentence = append(encodedSentence, encodedWord[pos].id)
pos = encodedWord[pos].next
}
}
if encodingConfig.eos {
if m.specialTokens.eos == -1 {
logrus.Error("Cannot use eos - model was trained without it")
return encodedSentence, errors.New("model was trained without eos")
}
encodedSentence = append(encodedSentence, TokenID(m.specialTokens.eos))
}
if encodingConfig.reverse {
for i := 0; i < len(encodedSentence)/2; i++ {
encodedSentence[i], encodedSentence[len(encodedSentence)-i-1] =
encodedSentence[len(encodedSentence)-i-1], encodedSentence[i]
}
}
return encodedSentence, nil
}
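
// Illustrative usage sketch (not part of this diff), assuming a model already
// loaded with ReadModel:
//
//	ids, err := model.EncodeSentence("hello world", EncodingConfig{bos: true, eos: true})
//	// On success, ids holds the BOS id, the subword ids of "hello" and of
//	// "world" (each word's encoding starts from the space token), and the EOS id.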

// EncodeSentences takes a sequence of strings which consist of space-separated words and tokenizes
// each word according to the BPE rules. Through encodingConfig one can state whether to add BOS
// and EOS tokens (beginning and end of sentence) and whether to reverse the output sequences.
// EncodeSentences returns the numerical encodings of the sentences.
func (m Model) EncodeSentences(sentences []string, encodingConfig EncodingConfig) ([]EncodedString,
error) {
encodedSentences := make([]EncodedString, len(sentences))
for i, sentence := range sentences {
sentenceIds, err := m.EncodeSentence(sentence, encodingConfig)
if err != nil {
return encodedSentences, err
}
encodedSentences[i] = sentenceIds
}
return encodedSentences, nil
}
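
// Illustrative usage sketch (not part of this diff), assuming a model already
// loaded with ReadModel:
//
//	encoded, err := model.EncodeSentences([]string{"hello world", "goodbye"}, EncodingConfig{})
//	// On success, encoded holds one EncodedString per input sentence.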

// EncodeStream reads a sequence of strings which consist of space-separated words from the given
// stream and tokenizes each word according to the BPE rules. Through encodingConfig one can state
// whether to add BOS and EOS tokens (beginning and end of sentence) and whether to reverse the
// output sequences. EncodeStream returns the numerical encodings of the sentences.
func (m Model) EncodeStream(reader io.Reader, encodingConfig EncodingConfig) ([]EncodedString,
error) {
scanner := bufio.NewScanner(reader)
var encodedSentences []EncodedString
for scanner.Scan() {
sentenceIds, err := m.EncodeSentence(scanner.Text(), encodingConfig)
if err != nil {
return encodedSentences, err
}
encodedSentences = append(encodedSentences, sentenceIds)
}
err := scanner.Err()
return encodedSentences, err
}
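
// Illustrative usage sketch (not part of this diff): one sentence per line on
// the stream, assuming a model already loaded with ReadModel:
//
//	reader := strings.NewReader("first sentence\nsecond sentence\n")
//	encoded, err := model.EncodeStream(reader, EncodingConfig{})
//	// On success, encoded[0] and encoded[1] hold the token ids of the two lines.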