
Commit 2feb9cc

Merge pull request #7 from irinakhismatullina/encode

Encoding API

2 parents 54f64c4 + 7237c4c

File tree

2 files changed: +315 -15 lines


bpe.go

Lines changed: 202 additions & 1 deletion
@@ -2,6 +2,7 @@ package bpe
 
 import (
     "bufio"
+    "container/heap"
     "encoding/binary"
     "errors"
     "io"
@@ -14,6 +15,9 @@ import (
 // TokenID is a numerical identifier of the subword token
 type TokenID uint32
 
+// TokenIDPair is a concatenation of two TokenIDs that is used as the key type in rule2id map.
+type TokenIDPair uint64
+
 // EncodedString is a sequence of subword token identifiers
 type EncodedString []TokenID

@@ -24,6 +28,13 @@ const (
     eosToken = "<EOS>"
 )
 
+// EncodingConfig is a configuration for encoding of strings
+type EncodingConfig struct {
+    bos     bool
+    eos     bool
+    reverse bool
+}
+
 type rule struct {
     left  TokenID
     right TokenID
@@ -43,6 +54,7 @@ type Model struct {
     char2id       map[rune]TokenID
     id2char       map[TokenID]rune
     rules         []rule
+    rule2id       map[TokenIDPair]int
     recipe        map[TokenID]EncodedString
     revRecipe     map[string]TokenID
     specialTokens specialTokens
@@ -54,13 +66,18 @@ func newModel(nRules int) *Model {
         make(map[rune]TokenID),
         make(map[TokenID]rune),
         make([]rule, nRules),
+        make(map[TokenIDPair]int),
         make(map[TokenID]EncodedString),
         make(map[string]TokenID),
         specialTokens{-1, -1, -1, -1},
         0,
     }
 }
 
+func newTokenIDPair(left, right TokenID) TokenIDPair {
+    return (TokenIDPair(left) << 32) + TokenIDPair(right)
+}
+
 // DecodeToken converts the sequence of chars' ids into the string -
 // sequence of the corresponding chars
 func DecodeToken(token EncodedString, id2char map[TokenID]rune) (string, error) {
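An aside on the new pair key: newTokenIDPair packs the left token id into the high 32 bits of a uint64 and the right id into the low 32 bits, so a rule's (left, right) pair can be used directly as the key of the rule2id map. A minimal package-internal sketch of the round trip (the unpacking and the fmt import are illustrative assumptions; the commit itself only ever packs):

    func ExampleTokenIDPair() {
        pair := newTokenIDPair(7, 42)       // 7<<32 + 42
        left := TokenID(pair >> 32)         // recover the left id
        right := TokenID(pair & 0xFFFFFFFF) // recover the right id
        fmt.Println(left, right)
        // Output: 7 42
    }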
@@ -167,7 +184,6 @@ func ReadModel(reader io.Reader) (*Model, error) {
         if err != nil {
             return model, err
         }
-        model.rules[i] = rule
         if _, ok := model.recipe[rule.left]; !ok {
             logrus.Errorf("%d: token id not described before", rule.left)
             return model, errors.New("token id is impossible")
@@ -176,6 +192,8 @@ func ReadModel(reader io.Reader) (*Model, error) {
             logrus.Errorf("%d: token id not described before", rule.right)
             return model, errors.New("token id is impossible")
         }
+        model.rules[i] = rule
+        model.rule2id[newTokenIDPair(rule.left, rule.right)] = i
         model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
         resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
         if err != nil {
@@ -194,6 +212,10 @@ func ReadModel(reader io.Reader) (*Model, error) {
         return model, err
     }
     model.specialTokens = specials
+    model.revRecipe[bosToken] = TokenID(specials.bos)
+    model.revRecipe[eosToken] = TokenID(specials.eos)
+    model.revRecipe[unkToken] = TokenID(specials.unk)
+    model.revRecipe[padToken] = TokenID(specials.pad)
     return model, err
 }

@@ -287,3 +309,182 @@ func (m Model) DecodeFromStream(reader io.Reader) ([]string, error) {
     }
     return sentences, nil
 }
+
+type encodingToken struct {
+    id   TokenID
+    prev int
+    next int
+}
+
+type mergeEvent struct {
+    priority int
+    pos      int
+}
+
+type mergeQueue []*mergeEvent
+
+func (mq mergeQueue) Len() int { return len(mq) }
+
+func (mq mergeQueue) Less(i, j int) bool {
+    return mq[i].priority < mq[j].priority ||
+        mq[i].priority == mq[j].priority && mq[i].pos < mq[j].pos
+}
+
+func (mq mergeQueue) Swap(i, j int) {
+    mq[i], mq[j] = mq[j], mq[i]
+}
+
+func (mq *mergeQueue) Push(x interface{}) {
+    *mq = append(*mq, x.(*mergeEvent))
+}
+
+func (mq *mergeQueue) Pop() interface{} {
+    old := *mq
+    n := len(old)
+    item := old[n-1]
+    old[n-1] = nil // avoid memory leak
+    *mq = old[0 : n-1]
+    return item
+}
+
+// EncodeSentence takes a string of space-separated words and tokenizes each word
+// according to the BPE rules. Through encodingConfig one can state whether to add BOS, EOS tokens
+// and whether to reverse the output sequences. EncodeSentence returns the numerical encoding
+// of the sentence.
+func (m Model) EncodeSentence(sentence string, encodingConfig EncodingConfig,
+) (EncodedString, error) {
+    var encodedSentence EncodedString
+
+    if encodingConfig.bos {
+        if m.specialTokens.bos == -1 {
+            logrus.Error("Cannot use bos - model was trained without it")
+            return encodedSentence, errors.New("model was trained without bos")
+        }
+        encodedSentence = append(encodedSentence, TokenID(m.specialTokens.bos))
+    }
+    for _, word := range strings.Fields(sentence) {
+        if len(word) == 0 {
+            continue
+        }
+        var encodedWord = []encodingToken{{m.spaceID, -1, 1}}
+        var pendingMerges mergeQueue
+        // Check whether two consecutive tokens can be merged and if so add merge suggestion to
+        // the priority queue
+        pushIfRuleExists := func(leftPos int) {
+            rightPos := encodedWord[leftPos].next
+            ruleCandidate := newTokenIDPair(encodedWord[leftPos].id, encodedWord[rightPos].id)
+            if priority, ok := m.rule2id[ruleCandidate]; ok {
+                heap.Push(&pendingMerges, &mergeEvent{priority, leftPos})
+            }
+        }
+        // Build linked list corresponding to the word's split on known chars and unknown tokens
+        unknownToken := false
+        for _, char := range word {
+            if charID, ok := m.char2id[char]; ok {
+                if unknownToken {
+                    encodedWord = append(encodedWord,
+                        encodingToken{TokenID(m.specialTokens.unk), len(encodedWord) - 1,
+                            len(encodedWord) + 1})
+                    unknownToken = false
+                }
+                encodedWord = append(encodedWord,
+                    encodingToken{charID, len(encodedWord) - 1, len(encodedWord) + 1})
+                pushIfRuleExists(len(encodedWord) - 2)
+            } else {
+                unknownToken = true
+            }
+        }
+        if unknownToken {
+            encodedWord = append(encodedWord,
+                encodingToken{TokenID(m.specialTokens.unk), len(encodedWord) - 1,
+                    len(encodedWord) + 1})
+        }
+        encodedWord[len(encodedWord)-1].next = -1
+        // Perform merges of subword tokens in the word according to the BPE model rules
+        for len(pendingMerges) > 0 {
+            event := heap.Pop(&pendingMerges).(*mergeEvent)
+            proposedRule := m.rules[event.priority]
+            leftPos := event.pos
+            leftToken := encodedWord[leftPos]
+            rightPos := leftToken.next
+            if rightPos == -1 {
+                continue
+            }
+            rightToken := encodedWord[rightPos]
+            // Check that the tokens suggested for the merge have not changed
+            if proposedRule.left != leftToken.id || proposedRule.right != rightToken.id {
+                continue
+            }
+            // Create token as a merge of the right and the left ones
+            leftToken.next = rightToken.next
+            leftToken.id = proposedRule.result
+            // Put merged token on the place of the left token
+            encodedWord[leftPos] = leftToken
+            // Put 'empty' token on the place of the right token
+            encodedWord[rightPos] = encodingToken{0, -1, -1}
+            // Add suggestions for merges for the new merged token
+            if rightToken.next != -1 {
+                encodedWord[rightToken.next].prev = leftPos
+                pushIfRuleExists(leftPos)
+            }
+            if leftToken.prev != -1 {
+                pushIfRuleExists(leftToken.prev)
+            }
+        }
+        // Retrieve all tokens that are left and append them to the result for the whole sentence
+        for pos := 0; pos > -1; {
+            encodedSentence = append(encodedSentence, encodedWord[pos].id)
+            pos = encodedWord[pos].next
+        }
+    }
+    if encodingConfig.eos {
+        if m.specialTokens.eos == -1 {
+            logrus.Error("Cannot use eos - model was trained without it")
+            return encodedSentence, errors.New("model was trained without eos")
+        }
+        encodedSentence = append(encodedSentence, TokenID(m.specialTokens.eos))
+    }
+    if encodingConfig.reverse {
+        for i := 0; i < len(encodedSentence)/2; i++ {
+            encodedSentence[i], encodedSentence[len(encodedSentence)-i-1] =
+                encodedSentence[len(encodedSentence)-i-1], encodedSentence[i]
+        }
+    }
+    return encodedSentence, nil
+}
+
+// EncodeSentences takes a sequence of strings which consist of space-separated words and tokenizes
+// each word according to the BPE rules. Through encodingConfig one can state whether to add BOS
+// and EOS tokens (beginning and end of sentence) and whether to reverse the output sequences.
+// EncodeSentences returns the numerical encodings of the sentences.
+func (m Model) EncodeSentences(sentences []string, encodingConfig EncodingConfig) ([]EncodedString,
+    error) {
+    encodedSentences := make([]EncodedString, len(sentences))
+    for i, sentence := range sentences {
+        sentenceIds, err := m.EncodeSentence(sentence, encodingConfig)
+        if err != nil {
+            return encodedSentences, err
+        }
+        encodedSentences[i] = sentenceIds
+    }
+    return encodedSentences, nil
+}
+
+// EncodeStream reads a sequence of strings which consist of space-separated words from the given
+// stream and tokenizes each word according to the BPE rules. Through encodingConfig one can state
+// whether to add BOS and EOS tokens (beginning and end of sentence) and whether to reverse the
+// output sequences. EncodeStream returns the numerical encodings of the sentences.
+func (m Model) EncodeStream(reader io.Reader, encodingConfig EncodingConfig) ([]EncodedString,
+    error) {
+    scanner := bufio.NewScanner(reader)
+    var encodedSentences []EncodedString
+    for scanner.Scan() {
+        sentenceIds, err := m.EncodeSentence(scanner.Text(), encodingConfig)
+        if err != nil {
+            return encodedSentences, err
+        }
+        encodedSentences = append(encodedSentences, sentenceIds)
+    }
+    err := scanner.Err()
+    return encodedSentences, err
+}
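A note on the merge machinery added above: mergeQueue implements heap.Interface from the standard container/heap package. The priority of a mergeEvent is the index of the rule in m.rules, so merges learned earlier during BPE training pop first, and pos breaks ties so the leftmost applicable merge wins. A self-contained sketch with toy values (in EncodeSentence both fields come from rule2id lookups):

    func Example_mergeQueue() {
        var pending mergeQueue
        heap.Push(&pending, &mergeEvent{priority: 3, pos: 0})
        heap.Push(&pending, &mergeEvent{priority: 1, pos: 4})
        heap.Push(&pending, &mergeEvent{priority: 1, pos: 2})
        for pending.Len() > 0 {
            e := heap.Pop(&pending).(*mergeEvent)
            fmt.Println(e.priority, e.pos)
        }
        // Output:
        // 1 2
        // 1 4
        // 3 0
    }

Using the rule index as the priority reproduces the training-time merge order, which is what makes the greedy encoding deterministic.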

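Putting the new API together, a package-internal usage sketch. It must live inside package bpe (for example, in a test file) because EncodingConfig's fields are unexported; the helper name encodeDemo and the fmt/io/strings imports are illustrative assumptions. Note that, per the linked-list initialization in EncodeSentence, each encoded word starts with the model's space marker (m.spaceID).

    // encodeDemo is a hypothetical package-internal helper showing the
    // three entry points end to end. reader supplies a serialized model.
    func encodeDemo(reader io.Reader) error {
        model, err := ReadModel(reader)
        if err != nil {
            return err
        }
        config := EncodingConfig{bos: true, eos: true}

        // One sentence at a time.
        ids, err := model.EncodeSentence("hello world", config)
        if err != nil {
            return err
        }
        fmt.Println(len(ids)) // BOS + space/word tokens + EOS

        // A batch of sentences.
        batch, err := model.EncodeSentences([]string{"hello world", "foo bar"}, config)
        if err != nil {
            return err
        }

        // One sentence per line from any stream.
        streamed, err := model.EncodeStream(strings.NewReader("hello world\nfoo bar\n"), config)
        if err != nil {
            return err
        }
        fmt.Println(len(batch), len(streamed)) // 2 2
        return nil
    }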