@@ -8,44 +8,50 @@ import (
88 "github.com/sirupsen/logrus"
99)
1010
11- type TokenId int32
11+ // TokenID is a numerical identitier of the subword token
12+ type TokenID uint32
1213
13- type DecodedToken []TokenId
14+ // EncodedToken is a sequence of subword tokens ids
15+ type EncodedToken []TokenID
1416
15- type Rule struct {
16- left TokenId
17- right TokenId
18- result TokenId
17+ type rule struct {
18+ left TokenID
19+ right TokenID
20+ result TokenID
1921}
2022
21- type SpecialTokens struct {
22- unk TokenId
23- pad TokenId
24- bos TokenId
25- eos TokenId
23+ type specialTokens struct {
24+ unk int
25+ pad int
26+ bos int
27+ eos int
2628}
2729
30+ // Model is a Byte-Pair encoding model, which supports encoding and decoding text into sequences
31+ // of most frequent subword tokens
2832type Model struct {
29- char2id map [rune ]TokenId
30- id2char map [TokenId ]rune
31- rules []Rule
32- recipe map [TokenId ] DecodedToken
33- revRecipe map [string ]TokenId
34- specialTokens SpecialTokens
33+ char2id map [rune ]TokenID
34+ id2char map [TokenID ]rune
35+ rules []rule
36+ recipe map [TokenID ] EncodedToken
37+ revRecipe map [string ]TokenID
38+ specialTokens specialTokens
3539}
3640
37- func NewModel (nRules int ) * Model {
41+ func newModel (nRules int ) * Model {
3842 return & Model {
39- make (map [rune ]TokenId ),
40- make (map [TokenId ]rune ),
41- make ([]Rule , nRules ),
42- make (map [TokenId ] DecodedToken ),
43- make (map [string ]TokenId ),
44- SpecialTokens {- 1 , - 1 , - 1 , - 1 },
43+ make (map [rune ]TokenID ),
44+ make (map [TokenID ]rune ),
45+ make ([]rule , nRules ),
46+ make (map [TokenID ] EncodedToken ),
47+ make (map [string ]TokenID ),
48+ specialTokens {- 1 , - 1 , - 1 , - 1 },
4549 }
4650}
4751
48- func DecodedTokenToString (token DecodedToken , id2char map [TokenId ]rune ) (string , error ) {
52+ // DecodeToken converts the sequence of chars' ids into the string -
53+ // sequence of the corresponding chars
54+ func DecodeToken (token EncodedToken , id2char map [TokenID ]rune ) (string , error ) {
4955 word := ""
5056 for _ , id := range token {
5157 if char , ok := id2char [id ]; ok {
@@ -57,7 +63,8 @@ func DecodedTokenToString(token DecodedToken, id2char map[TokenId]rune) (string,
5763 return word , nil
5864}
5965
60- func ReadModel (reader io.Reader ) (* Model , error ) {
66+ // ReadModelFromText loads the BPE model from the text dump
67+ func ReadModelFromText (reader io.Reader ) (* Model , error ) {
6168 scanner := bufio .NewScanner (reader )
6269 var nChars , nRules int
6370 scanner .Scan ()
@@ -66,24 +73,23 @@ func ReadModel(reader io.Reader) (*Model, error) {
6673 logrus .Fatal ("Wrong input format: " , err )
6774 return & Model {}, err
6875 }
69- model := NewModel (nRules )
70- model .rules = make ([]Rule , nRules )
76+ model := newModel (nRules )
7177 for i := 0 ; i < nChars ; i ++ {
7278 var char rune
73- var charId TokenId
79+ var charID TokenID
7480 scanner .Scan ()
75- _ , err = fmt .Sscanf (scanner .Text (), "%d %d" , & char , & charId )
81+ _ , err = fmt .Sscanf (scanner .Text (), "%d %d" , & char , & charID )
7682 if err != nil {
7783 logrus .Fatal ("Wrong input format: " , err )
7884 return model , err
7985 }
80- model .char2id [char ] = charId
81- model .id2char [charId ] = char
82- model .recipe [charId ] = DecodedToken { charId }
83- model .revRecipe [string (char )] = charId
86+ model .char2id [char ] = charID
87+ model .id2char [charID ] = char
88+ model .recipe [charID ] = EncodedToken { charID }
89+ model .revRecipe [string (char )] = charID
8490 }
8591 for i := 0 ; i < nRules ; i ++ {
86- var rule Rule
92+ var rule rule
8793 scanner .Scan ()
8894 _ , err = fmt .Sscanf (scanner .Text (), "%d %d %d" , & rule .left , & rule .right , & rule .result )
8995 if err != nil {
@@ -92,7 +98,7 @@ func ReadModel(reader io.Reader) (*Model, error) {
9298 }
9399 model .rules [i ] = rule
94100 model .recipe [rule .result ] = append (model .recipe [rule .left ], model .recipe [rule .right ]... )
95- resultString , err := DecodedTokenToString (model .recipe [rule .result ], model .id2char )
101+ resultString , err := DecodeToken (model .recipe [rule .result ], model .id2char )
96102 if err != nil {
97103 logrus .Fatal ("Unexpected token id inside the rules: " , err )
98104 return model , err
0 commit comments