Skip to content

Commit 4bbb95e

Browse files
Expand tests
Signed-off-by: Irina Khismatullina <irenekhismatullina@gmail.com>
1 parent 9bbc64d commit 4bbb95e

4 files changed

Lines changed: 141 additions & 48 deletions

File tree

.travis.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ install:
1616
script:
1717
- make install-dev-deps
1818
- make check-style
19-
- make test
19+
- make test-coverage
20+
- make codecov
2021

2122
matrix:
2223
fast_finish: true

bpe.go

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package bpe
22

33
import (
4-
"bufio"
54
"encoding/binary"
5+
"errors"
66
"io"
77

88
"github.com/sirupsen/logrus"
@@ -57,60 +57,68 @@ func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) {
5757
if char, ok := id2char[id]; ok {
5858
word = word + string(char)
5959
} else {
60-
logrus.Fatalf("%d key not found in id2char", id)
60+
logrus.Errorf("Decode failure: %d token id has no corresponding char", id)
61+
return "", errors.New("key not found in id2char")
6162
}
6263
}
6364
return word, nil
6465
}
6566

66-
func specialTokensToBin(specials specialTokens) []byte {
67+
func (s specialTokens) toBinary() []byte {
6768
bytesArray := make([]byte, 16)
68-
binary.BigEndian.PutUint32(bytesArray, uint32(specials.unk))
69-
binary.BigEndian.PutUint32(bytesArray[4:], uint32(specials.pad))
70-
binary.BigEndian.PutUint32(bytesArray[8:], uint32(specials.bos))
71-
binary.BigEndian.PutUint32(bytesArray[12:], uint32(specials.eos))
69+
binary.BigEndian.PutUint32(bytesArray, uint32(s.unk))
70+
binary.BigEndian.PutUint32(bytesArray[4:], uint32(s.pad))
71+
binary.BigEndian.PutUint32(bytesArray[8:], uint32(s.bos))
72+
binary.BigEndian.PutUint32(bytesArray[12:], uint32(s.eos))
7273
return bytesArray
7374
}
7475

75-
func binToSpecialTokens(bytesArray []byte) specialTokens {
76+
func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) {
7677
var s specialTokens
78+
if len(bytesArray) < 16 {
79+
logrus.Error("Bytes array length is too small")
80+
return s, errors.New("bytes array is too small")
81+
}
7782
s.unk = int32(binary.BigEndian.Uint32(bytesArray))
7883
s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:]))
7984
s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:]))
8085
s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:]))
81-
return s
86+
return s, nil
8287
}
8388

84-
func ruleToBin(rule rule) []byte {
89+
func (r rule) toBinary() []byte {
8590
bytesArray := make([]byte, 12)
86-
binary.BigEndian.PutUint32(bytesArray, uint32(rule.left))
87-
binary.BigEndian.PutUint32(bytesArray[4:], uint32(rule.right))
88-
binary.BigEndian.PutUint32(bytesArray[8:], uint32(rule.result))
91+
binary.BigEndian.PutUint32(bytesArray, uint32(r.left))
92+
binary.BigEndian.PutUint32(bytesArray[4:], uint32(r.right))
93+
binary.BigEndian.PutUint32(bytesArray[8:], uint32(r.result))
8994
return bytesArray
9095
}
9196

92-
func binToRule(bytesArray []byte) rule {
97+
func binaryToRule(bytesArray []byte) (rule, error) {
9398
var r rule
99+
if len(bytesArray) < 12 {
100+
logrus.Error("Bytes array length is too small")
101+
return r, errors.New("bytes array is too small")
102+
}
94103
r.left = TokenID(binary.BigEndian.Uint32(bytesArray))
95104
r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:]))
96105
r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:]))
97-
return r
106+
return r, nil
98107
}
99108

100-
// ReadModelFromBinary loads the BPE model from the binary dump
101-
func ReadModelFromBinary(reader io.Reader) (*Model, error) {
102-
bytesReader := bufio.NewReader(reader)
109+
// ReadModel loads the BPE model from the binary dump
110+
func ReadModel(reader io.Reader) (*Model, error) {
103111
buf := make([]byte, 4)
104112
var nChars, nRules int
105-
_, err := bytesReader.Read(buf)
113+
_, err := io.ReadFull(reader, buf)
106114
if err != nil {
107-
logrus.Fatal("Broken input: ", err)
115+
logrus.Error("Broken input: ", err)
108116
return &Model{}, err
109117
}
110118
nChars = int(binary.BigEndian.Uint32(buf))
111-
_, err = bytesReader.Read(buf)
119+
_, err = io.ReadFull(reader, buf)
112120
if err != nil {
113-
logrus.Fatal("Broken input: ", err)
121+
logrus.Error("Broken input: ", err)
114122
return &Model{}, err
115123
}
116124
nRules = int(binary.BigEndian.Uint32(buf))
@@ -119,15 +127,15 @@ func ReadModelFromBinary(reader io.Reader) (*Model, error) {
119127
for i := 0; i < nChars; i++ {
120128
var char rune
121129
var charID TokenID
122-
_, err = bytesReader.Read(buf)
130+
_, err = io.ReadFull(reader, buf)
123131
if err != nil {
124-
logrus.Fatal("Broken input: ", err)
132+
logrus.Error("Broken input: ", err)
125133
return &Model{}, err
126134
}
127135
char = rune(binary.BigEndian.Uint32(buf))
128-
_, err = bytesReader.Read(buf)
136+
_, err = io.ReadFull(reader, buf)
129137
if err != nil {
130-
logrus.Fatal("Broken input: ", err)
138+
logrus.Error("Broken input: ", err)
131139
return &Model{}, err
132140
}
133141
charID = TokenID(binary.BigEndian.Uint32(buf))
@@ -138,27 +146,38 @@ func ReadModelFromBinary(reader io.Reader) (*Model, error) {
138146
}
139147
ruleBuf := make([]byte, 12)
140148
for i := 0; i < nRules; i++ {
141-
_, err = bytesReader.Read(ruleBuf)
149+
_, err = io.ReadFull(reader, ruleBuf)
142150
if err != nil {
143-
logrus.Fatal("Broken input: ", err)
151+
logrus.Error("Broken input: ", err)
144152
return &Model{}, err
145153
}
146-
rule := binToRule(ruleBuf)
154+
rule, err := binaryToRule(ruleBuf)
155+
if err != nil {
156+
return model, err
157+
}
147158
model.rules[i] = rule
159+
if _, ok := model.recipe[rule.left]; !ok {
160+
logrus.Errorf("%d: token id not described before", rule.left)
161+
return model, errors.New("key not found in id2char")
162+
}
163+
if _, ok := model.recipe[rule.right]; !ok {
164+
logrus.Errorf("%d: token id not described before", rule.right)
165+
return model, errors.New("key not found in id2char")
166+
}
148167
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
149168
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
150169
if err != nil {
151-
logrus.Fatal("Unexpected token id inside the rules: ", err)
170+
logrus.Error("Unexpected token id inside the rules: ", err)
152171
return model, err
153172
}
154173
model.revRecipe[resultString] = rule.result
155174
}
156175
specialTokensBuf := make([]byte, 16)
157-
_, err = bytesReader.Read(specialTokensBuf)
176+
_, err = io.ReadFull(reader, specialTokensBuf)
158177
if err != nil {
159-
logrus.Fatal("Broken input: ", err)
178+
logrus.Error("Broken input: ", err)
160179
return &Model{}, err
161180
}
162-
model.specialTokens = binToSpecialTokens(specialTokensBuf)
163-
return model, nil
181+
model.specialTokens, err = binaryToSpecialTokens(specialTokensBuf)
182+
return model, err
164183
}

bpe_test.go

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,54 @@ func TestNewModel(t *testing.T) {
1212
require.Equal(t, 10, len(model.rules))
1313
}
1414

15-
func TestDecodedTokenToString(t *testing.T) {
15+
func TestDecodeToken(t *testing.T) {
1616
id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]}
1717
word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char)
1818
require.NoError(t, err)
1919
require.Equal(t, "abacc", word)
2020
}
2121

22-
func TestSpecialTokensToBin(t *testing.T) {
22+
func TestSpecialTokensToBinary(t *testing.T) {
2323
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
2424
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
25-
require.Equal(t, bytesArray, specialTokensToBin(specials))
25+
require.Equal(t, bytesArray, specials.toBinary())
2626
}
2727

28-
func TestBinToSpecialTokens(t *testing.T) {
28+
func TestBinaryToSpecialTokens(t *testing.T) {
2929
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
30-
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
31-
require.Equal(t, specials, binToSpecialTokens(bytesArray))
30+
expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
31+
specials, err := binaryToSpecialTokens(bytesArray)
32+
require.NoError(t, err)
33+
require.Equal(t, expected, specials)
34+
bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0}
35+
specials, err = binaryToSpecialTokens(bytesArray)
36+
require.Error(t, err)
37+
bytesArray = []byte{}
38+
specials, err = binaryToSpecialTokens(bytesArray)
39+
require.Error(t, err)
3240
}
3341

34-
func TestRuleToBin(t *testing.T) {
42+
func TestRuleToBinary(t *testing.T) {
3543
rule := rule{1, 2, 257}
3644
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
37-
require.Equal(t, bytesArray, ruleToBin(rule))
45+
require.Equal(t, bytesArray, rule.toBinary())
3846
}
3947

40-
func TestBinToRule(t *testing.T) {
41-
rule := rule{1, 2, 257}
48+
func TestBinaryToRule(t *testing.T) {
49+
expected := rule{1, 2, 257}
4250
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
43-
require.Equal(t, rule, binToRule(bytesArray))
51+
rule, err := binaryToRule(bytesArray)
52+
require.NoError(t, err)
53+
require.Equal(t, expected, rule)
54+
bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1}
55+
rule, err = binaryToRule(bytesArray)
56+
require.Error(t, err)
57+
bytesArray = []byte{}
58+
rule, err = binaryToRule(bytesArray)
59+
require.Error(t, err)
4460
}
4561

46-
func TestReadModelFromBinary(t *testing.T) {
62+
func TestReadModel(t *testing.T) {
4763
reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
4864
0, 0, 0, 99, 0, 0, 0, 6,
4965
0, 0, 0, 98, 0, 0, 0, 7,
@@ -64,7 +80,52 @@ func TestReadModelFromBinary(t *testing.T) {
6480
"_a": 9, "_b": 12, "_c": 10, "_d": 11},
6581
specialTokens{1, 0, 2, 3},
6682
}
67-
model, err := ReadModelFromBinary(reader)
83+
model, err := ReadModel(reader)
6884
require.NoError(t, err)
6985
require.Equal(t, expected, *model)
86+
87+
reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
88+
0, 0, 0, 99, 0, 0, 0, 6,
89+
0, 0, 0, 98, 0, 0, 0, 7,
90+
0, 0, 0, 95, 0, 0, 0, 4,
91+
0, 0, 0, 100, 0, 0, 0, 5,
92+
0, 0, 0, 97, 0, 0, 0, 8,
93+
0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
94+
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
95+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
96+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
97+
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3,
98+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
99+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12})
100+
model, err = ReadModel(reader)
101+
require.NoError(t, err)
102+
require.Equal(t, expected, *model)
103+
104+
reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
105+
0, 0, 0, 99, 0, 0, 0, 6,
106+
0, 0, 0, 98, 0, 0, 0, 7,
107+
0, 0, 0, 95, 0, 0, 0, 4,
108+
0, 0, 0, 100, 0, 0, 0, 5,
109+
0, 0, 0, 97, 0, 0, 0, 8,
110+
0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
111+
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
112+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
113+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
114+
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0})
115+
model, err = ReadModel(reader)
116+
require.Error(t, err)
117+
118+
reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
119+
0, 0, 0, 99, 0, 0, 0, 6,
120+
0, 0, 0, 98, 0, 0, 0, 7,
121+
0, 0, 0, 95, 0, 0, 0, 4,
122+
0, 0, 0, 100, 0, 0, 0, 5,
123+
0, 0, 0, 97, 0, 0, 0, 8,
124+
0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9,
125+
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
126+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
127+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
128+
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
129+
model, err = ReadModel(reader)
130+
require.Error(t, err)
70131
}

go.sum

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3+
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
4+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5+
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
6+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
7+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
8+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
9+
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
10+
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
11+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
12+
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

0 commit comments

Comments
 (0)