Skip to content

Commit 62d0680

Browse files
Fix style
Signed-off-by: Irina Khismatullina <irenekhismatullina@gmail.com>
1 parent fb6673d commit 62d0680

4 files changed

Lines changed: 54 additions & 75 deletions

File tree

bpe.go

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,50 @@ import (
88
"github.com/sirupsen/logrus"
99
)
1010

11-
type TokenId int32
11+
// TokenID is a numerical identitier of the subword token
12+
type TokenID uint32
1213

13-
type DecodedToken []TokenId
14+
// EncodedToken is a sequence of subword tokens ids
15+
type EncodedToken []TokenID
1416

15-
type Rule struct {
16-
left TokenId
17-
right TokenId
18-
result TokenId
17+
type rule struct {
18+
left TokenID
19+
right TokenID
20+
result TokenID
1921
}
2022

21-
type SpecialTokens struct {
22-
unk TokenId
23-
pad TokenId
24-
bos TokenId
25-
eos TokenId
23+
type specialTokens struct {
24+
unk int
25+
pad int
26+
bos int
27+
eos int
2628
}
2729

30+
// Model is a Byte-Pair encoding model, which supports encoding and decoding text into sequences
31+
// of most frequent subword tokens
2832
type Model struct {
29-
char2id map[rune]TokenId
30-
id2char map[TokenId]rune
31-
rules []Rule
32-
recipe map[TokenId]DecodedToken
33-
revRecipe map[string]TokenId
34-
specialTokens SpecialTokens
33+
char2id map[rune]TokenID
34+
id2char map[TokenID]rune
35+
rules []rule
36+
recipe map[TokenID]EncodedToken
37+
revRecipe map[string]TokenID
38+
specialTokens specialTokens
3539
}
3640

37-
func NewModel(nRules int) *Model {
41+
func newModel(nRules int) *Model {
3842
return &Model{
39-
make(map[rune]TokenId),
40-
make(map[TokenId]rune),
41-
make([]Rule, nRules),
42-
make(map[TokenId]DecodedToken),
43-
make(map[string]TokenId),
44-
SpecialTokens{-1, -1, -1, -1},
43+
make(map[rune]TokenID),
44+
make(map[TokenID]rune),
45+
make([]rule, nRules),
46+
make(map[TokenID]EncodedToken),
47+
make(map[string]TokenID),
48+
specialTokens{-1, -1, -1, -1},
4549
}
4650
}
4751

48-
func DecodedTokenToString(token DecodedToken, id2char map[TokenId]rune) (string, error) {
52+
// DecodeToken converts the sequence of chars' ids into the string -
53+
// sequence of the corresponding chars
54+
func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) {
4955
word := ""
5056
for _, id := range token {
5157
if char, ok := id2char[id]; ok {
@@ -57,7 +63,8 @@ func DecodedTokenToString(token DecodedToken, id2char map[TokenId]rune) (string,
5763
return word, nil
5864
}
5965

60-
func ReadModel(reader io.Reader) (*Model, error) {
66+
// ReadModelFromText loads the BPE model from the text dump
67+
func ReadModelFromText(reader io.Reader) (*Model, error) {
6168
scanner := bufio.NewScanner(reader)
6269
var nChars, nRules int
6370
scanner.Scan()
@@ -66,24 +73,23 @@ func ReadModel(reader io.Reader) (*Model, error) {
6673
logrus.Fatal("Wrong input format: ", err)
6774
return &Model{}, err
6875
}
69-
model := NewModel(nRules)
70-
model.rules = make([]Rule, nRules)
76+
model := newModel(nRules)
7177
for i := 0; i < nChars; i++ {
7278
var char rune
73-
var charId TokenId
79+
var charID TokenID
7480
scanner.Scan()
75-
_, err = fmt.Sscanf(scanner.Text(), "%d %d", &char, &charId)
81+
_, err = fmt.Sscanf(scanner.Text(), "%d %d", &char, &charID)
7682
if err != nil {
7783
logrus.Fatal("Wrong input format: ", err)
7884
return model, err
7985
}
80-
model.char2id[char] = charId
81-
model.id2char[charId] = char
82-
model.recipe[charId] = DecodedToken{charId}
83-
model.revRecipe[string(char)] = charId
86+
model.char2id[char] = charID
87+
model.id2char[charID] = char
88+
model.recipe[charID] = EncodedToken{charID}
89+
model.revRecipe[string(char)] = charID
8490
}
8591
for i := 0; i < nRules; i++ {
86-
var rule Rule
92+
var rule rule
8793
scanner.Scan()
8894
_, err = fmt.Sscanf(scanner.Text(), "%d %d %d", &rule.left, &rule.right, &rule.result)
8995
if err != nil {
@@ -92,7 +98,7 @@ func ReadModel(reader io.Reader) (*Model, error) {
9298
}
9399
model.rules[i] = rule
94100
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
95-
resultString, err := DecodedTokenToString(model.recipe[rule.result], model.id2char)
101+
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
96102
if err != nil {
97103
logrus.Fatal("Unexpected token id inside the rules: ", err)
98104
return model, err

bpe_test.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
package bpe
22

33
import (
4-
"github.com/stretchr/testify/require"
54
"strings"
65
"testing"
6+
7+
"github.com/stretchr/testify/require"
78
)
89

910
func TestNewModel(t *testing.T) {
10-
model := NewModel(10)
11+
model := newModel(10)
1112
require.Equal(t, 10, len(model.rules))
1213
}
1314

1415
func TestDecodedTokenToString(t *testing.T) {
15-
id2char := map[TokenId]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]}
16-
word, err := DecodedTokenToString(DecodedToken{1, 2, 1, 3, 3}, id2char)
16+
id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]}
17+
word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char)
1718
require.NoError(t, err)
1819
require.Equal(t, "abacc", word)
1920
}
@@ -31,14 +32,14 @@ func TestReadModel(t *testing.T) {
3132
4 7 12
3233
1 0 2 4`)
3334
expected := Model{
34-
map[rune]TokenId{97: 8, 98: 7, 99: 6, 100: 5, 95: 4},
35-
map[TokenId]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97},
36-
[]Rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}},
37-
map[TokenId]DecodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}},
38-
map[string]TokenId{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4,
35+
map[rune]TokenID{97: 8, 98: 7, 99: 6, 100: 5, 95: 4},
36+
map[TokenID]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97},
37+
[]rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}},
38+
map[TokenID]EncodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}},
39+
map[string]TokenID{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4,
3940
"_a": 9, "_b": 12, "_c": 10, "_d": 11},
40-
SpecialTokens{1, 0, 2, 4},
41+
specialTokens{1, 0, 2, 4},
4142
}
42-
model, _ := ReadModel(reader)
43+
model, _ := ReadModelFromText(reader)
4344
require.Equal(t, expected, *model)
4445
}

go.sum

Lines changed: 0 additions & 21 deletions
This file was deleted.

main.go

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)