@@ -2,6 +2,7 @@ package bpe
22
33import (
44 "bufio"
5+ "container/heap"
56 "encoding/binary"
67 "errors"
78 "io"
@@ -14,6 +15,9 @@ import (
// TokenID is a numerical identifier of the subword token.
type TokenID uint32
1617
// TokenIDPair is a concatenation of two TokenIDs that is used as the key type in rule2id map.
// The left TokenID occupies the high 32 bits and the right one the low 32 bits
// (see newTokenIDPair).
type TokenIDPair uint64
// EncodedString is a sequence of subword token identifiers.
type EncodedString []TokenID
1923
@@ -24,6 +28,13 @@ const (
2428 eosToken = "<EOS>"
2529)
2630
// EncodingConfig is a configuration for encoding of strings.
type EncodingConfig struct {
	bos     bool // prepend the BOS (begin-of-sentence) token to every encoded sentence
	eos     bool // append the EOS (end-of-sentence) token to every encoded sentence
	reverse bool // reverse the final token sequence of each sentence
}
37+
2738type rule struct {
2839 left TokenID
2940 right TokenID
@@ -43,6 +54,7 @@ type Model struct {
4354 char2id map [rune ]TokenID
4455 id2char map [TokenID ]rune
4556 rules []rule
57+ rule2id map [TokenIDPair ]int
4658 recipe map [TokenID ]EncodedString
4759 revRecipe map [string ]TokenID
4860 specialTokens specialTokens
@@ -54,13 +66,18 @@ func newModel(nRules int) *Model {
5466 make (map [rune ]TokenID ),
5567 make (map [TokenID ]rune ),
5668 make ([]rule , nRules ),
69+ make (map [TokenIDPair ]int ),
5770 make (map [TokenID ]EncodedString ),
5871 make (map [string ]TokenID ),
5972 specialTokens {- 1 , - 1 , - 1 , - 1 },
6073 0 ,
6174 }
6275}
6376
77+ func newTokenIDPair (left , right TokenID ) TokenIDPair {
78+ return (TokenIDPair (left ) << 32 ) + TokenIDPair (right )
79+ }
80+
6481// DecodeToken converts the sequence of chars' ids into the string -
6582// sequence of the corresponding chars
6683func DecodeToken (token EncodedString , id2char map [TokenID ]rune ) (string , error ) {
@@ -167,7 +184,6 @@ func ReadModel(reader io.Reader) (*Model, error) {
167184 if err != nil {
168185 return model , err
169186 }
170- model .rules [i ] = rule
171187 if _ , ok := model .recipe [rule .left ]; ! ok {
172188 logrus .Errorf ("%d: token id not described before" , rule .left )
173189 return model , errors .New ("token id is impossible" )
@@ -176,6 +192,8 @@ func ReadModel(reader io.Reader) (*Model, error) {
176192 logrus .Errorf ("%d: token id not described before" , rule .right )
177193 return model , errors .New ("token id is impossible" )
178194 }
195+ model .rules [i ] = rule
196+ model .rule2id [newTokenIDPair (rule .left , rule .right )] = i
179197 model .recipe [rule .result ] = append (model .recipe [rule .left ], model .recipe [rule .right ]... )
180198 resultString , err := DecodeToken (model .recipe [rule .result ], model .id2char )
181199 if err != nil {
@@ -194,6 +212,10 @@ func ReadModel(reader io.Reader) (*Model, error) {
194212 return model , err
195213 }
196214 model .specialTokens = specials
215+ model .revRecipe [bosToken ] = TokenID (specials .bos )
216+ model .revRecipe [eosToken ] = TokenID (specials .eos )
217+ model .revRecipe [unkToken ] = TokenID (specials .unk )
218+ model .revRecipe [padToken ] = TokenID (specials .pad )
197219 return model , err
198220}
199221
@@ -287,3 +309,182 @@ func (m Model) DecodeFromStream(reader io.Reader) ([]string, error) {
287309 }
288310 return sentences , nil
289311}
312+
// encodingToken is a node of the doubly linked list that represents a word
// during the BPE merge procedure. prev and next are indices into the backing
// slice of the list; -1 marks a list boundary (or a detached node).
type encodingToken struct {
	id   TokenID // current subword token at this position
	prev int     // index of the previous live token, -1 at the head
	next int     // index of the next live token, -1 at the tail
}
318+
// mergeEvent is a suggestion to merge the token at position pos with its
// successor by applying the rule with the given priority (its index in rules).
type mergeEvent struct {
	priority int
	pos      int
}

// mergeQueue is a min-heap of merge suggestions ordered by rule priority with
// the position in the word as a tie-breaker. It implements heap.Interface.
type mergeQueue []*mergeEvent

// Len reports the number of pending merge suggestions.
func (mq mergeQueue) Len() int { return len(mq) }

// Less orders suggestions by ascending rule priority; among equal priorities
// the leftmost position in the word wins.
func (mq mergeQueue) Less(i, j int) bool {
	if mq[i].priority == mq[j].priority {
		return mq[i].pos < mq[j].pos
	}
	return mq[i].priority < mq[j].priority
}

// Swap exchanges two suggestions in place.
func (mq mergeQueue) Swap(i, j int) {
	mq[j], mq[i] = mq[i], mq[j]
}

// Push appends a new suggestion; called by container/heap.
func (mq *mergeQueue) Push(x interface{}) {
	*mq = append(*mq, x.(*mergeEvent))
}

// Pop removes and returns the last suggestion; called by container/heap.
func (mq *mergeQueue) Pop() interface{} {
	q := *mq
	last := len(q) - 1
	event := q[last]
	q[last] = nil // drop the reference so the event can be garbage collected
	*mq = q[:last]
	return event
}
349+
350+ // EncodeSentence takes a string of space-separated words and tokenizes each word
351+ // according to the BPE rules. Through encodingConfig one can state whether to add BOS, EOS tokens
352+ // and whether to reverse the output sequences. EncodeSentence returns the numerical encoding
353+ // of the sentence.
354+ func (m Model ) EncodeSentence (sentence string , encodingConfig EncodingConfig ,
355+ ) (EncodedString , error ) {
356+ var encodedSentence EncodedString
357+
358+ if encodingConfig .bos {
359+ if m .specialTokens .bos == - 1 {
360+ logrus .Error ("Cannot use bos - model was trained without it" )
361+ return encodedSentence , errors .New ("model was trained withous bos" )
362+ }
363+ encodedSentence = append (encodedSentence , TokenID (m .specialTokens .bos ))
364+ }
365+ for _ , word := range strings .Fields (sentence ) {
366+ if len (word ) == 0 {
367+ continue
368+ }
369+ var encodedWord = []encodingToken {{m .spaceID , - 1 , 1 }}
370+ var pendingMerges mergeQueue
371+ // Check whether two consecutive tokens can be merged and if so add merge suggestion to
372+ // the priority queue
373+ pushIfRuleExists := func (leftPos int ) {
374+ rightPos := encodedWord [leftPos ].next
375+ ruleCandidate := newTokenIDPair (encodedWord [leftPos ].id , encodedWord [rightPos ].id )
376+ if priority , ok := m .rule2id [ruleCandidate ]; ok {
377+ heap .Push (& pendingMerges , & mergeEvent {priority , leftPos })
378+ }
379+ }
380+ // Build linked list corresponding to the word's split on known chars and unknown tokens
381+ unknownToken := false
382+ for _ , char := range word {
383+ if charID , ok := m .char2id [char ]; ok {
384+ if unknownToken {
385+ encodedWord = append (encodedWord ,
386+ encodingToken {TokenID (m .specialTokens .unk ), len (encodedWord ) - 1 ,
387+ len (encodedWord ) + 1 })
388+ unknownToken = false
389+ }
390+ encodedWord = append (encodedWord ,
391+ encodingToken {charID , len (encodedWord ) - 1 , len (encodedWord ) + 1 })
392+ pushIfRuleExists (len (encodedWord ) - 2 )
393+ } else {
394+ unknownToken = true
395+ }
396+ }
397+ if unknownToken {
398+ encodedWord = append (encodedWord ,
399+ encodingToken {TokenID (m .specialTokens .unk ), len (encodedWord ) - 1 ,
400+ len (encodedWord ) + 1 })
401+ }
402+ encodedWord [len (encodedWord )- 1 ].next = - 1
403+ // Perform merges of subword tokens in the word according to the BPE model rules
404+ for len (pendingMerges ) > 0 {
405+ event := heap .Pop (& pendingMerges ).(* mergeEvent )
406+ proposedRule := m .rules [event .priority ]
407+ leftPos := event .pos
408+ leftToken := encodedWord [leftPos ]
409+ rightPos := leftToken .next
410+ if rightPos == - 1 {
411+ continue
412+ }
413+ rightToken := encodedWord [rightPos ]
414+ // Check that the tokens suggested for the merge have not changed
415+ if proposedRule .left != leftToken .id || proposedRule .right != rightToken .id {
416+ continue
417+ }
418+ // Create token as a merge of the right and the left ones
419+ leftToken .next = rightToken .next
420+ leftToken .id = proposedRule .result
421+ // Put merged token on the place of the left token
422+ encodedWord [leftPos ] = leftToken
423+ // Put 'empty' token on the place of the right token
424+ encodedWord [rightPos ] = encodingToken {0 , - 1 , - 1 }
425+ // Add suggestions for merges for the new merged token
426+ if rightToken .next != - 1 {
427+ encodedWord [rightToken .next ].prev = leftPos
428+ pushIfRuleExists (leftPos )
429+ }
430+ if leftToken .prev != - 1 {
431+ pushIfRuleExists (leftToken .prev )
432+ }
433+ }
434+ // Retrieve all tokens that are left and append them to the result for the whole sentence
435+ for pos := 0 ; pos > - 1 ; {
436+ encodedSentence = append (encodedSentence , encodedWord [pos ].id )
437+ pos = encodedWord [pos ].next
438+ }
439+ }
440+ if encodingConfig .eos {
441+ if m .specialTokens .eos == - 1 {
442+ logrus .Error ("Cannot use eos - model was trained without it" )
443+ return encodedSentence , errors .New ("model was trained withous eos" )
444+ }
445+ encodedSentence = append (encodedSentence , TokenID (m .specialTokens .eos ))
446+ }
447+ if encodingConfig .reverse {
448+ for i := 0 ; i < len (encodedSentence )/ 2 ; i ++ {
449+ encodedSentence [i ], encodedSentence [len (encodedSentence )- i - 1 ] =
450+ encodedSentence [len (encodedSentence )- i - 1 ], encodedSentence [i ]
451+ }
452+ }
453+ return encodedSentence , nil
454+ }
455+
456+ // EncodeSentences takes a sequence of strings which consist of space-separated words and tokenizes
457+ // each word according to the BPE rules. Through encodingConfig one can state whether to add BOS
458+ // and EOS tokens (beginning and end of sentence) and whether to reverse the output sequences.
459+ // EncodeSentences returns the numerical encodings of the sentences.
460+ func (m Model ) EncodeSentences (sentences []string , encodingConfig EncodingConfig ) ([]EncodedString ,
461+ error ) {
462+ encodedSentence := make ([]EncodedString , len (sentences ))
463+ for i , sentence := range sentences {
464+ sentenceIds , err := m .EncodeSentence (sentence , encodingConfig )
465+ if err != nil {
466+ return encodedSentence , err
467+ }
468+ encodedSentence [i ] = sentenceIds
469+ }
470+ return encodedSentence , nil
471+ }
472+
473+ // EncodeStream reads a sequence of strings which consist of space-separated words from the given
474+ // stream and tokenizes each word according to the BPE rules. Through encodingConfig one can state
475+ // whether to add BOS and EOS tokens (beginning and end of sentence) and whether to reverse the
476+ // output sequences. EncodeStream returns the numerical encodings of the sentences.
477+ func (m Model ) EncodeStream (reader io.Reader , encodingConfig EncodingConfig ) ([]EncodedString ,
478+ error ) {
479+ scanner := bufio .NewScanner (reader )
480+ var encodedSentence []EncodedString
481+ for scanner .Scan () {
482+ sentenceIds , err := m .EncodeSentence (scanner .Text (), encodingConfig )
483+ if err != nil {
484+ return encodedSentence , err
485+ }
486+ encodedSentence = append (encodedSentence , sentenceIds )
487+ }
488+ err := scanner .Err ()
489+ return encodedSentence , err
490+ }
0 commit comments