Skip to content

Commit 5a79713

Browse files
Expand tests
Signed-off-by: Irina Khismatullina <irenekhismatullina@gmail.com>
1 parent 9bbc64d commit 5a79713

3 files changed

Lines changed: 130 additions & 54 deletions

File tree

.travis.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@ install:
1616
script:
1717
- make install-dev-deps
1818
- make check-style
19-
- make test
19+
- make test-coverage
20+
- make codecov
2021

2122
matrix:
2223
fast_finish: true

bpe.go

Lines changed: 54 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
package bpe
22

33
import (
4-
"bufio"
54
"encoding/binary"
5+
"errors"
66
"io"
77

88
"github.com/sirupsen/logrus"
@@ -57,60 +57,66 @@ func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) {
5757
if char, ok := id2char[id]; ok {
5858
word = word + string(char)
5959
} else {
60-
logrus.Fatalf("%d key not found in id2char", id)
60+
logrus.Errorf("Decode failure: %d token id has no corresponding char", id)
61+
return "", errors.New("key not found in id2char")
6162
}
6263
}
6364
return word, nil
6465
}
6566

66-
func specialTokensToBin(specials specialTokens) []byte {
67+
func (s specialTokens) toBinary() []byte {
6768
bytesArray := make([]byte, 16)
68-
binary.BigEndian.PutUint32(bytesArray, uint32(specials.unk))
69-
binary.BigEndian.PutUint32(bytesArray[4:], uint32(specials.pad))
70-
binary.BigEndian.PutUint32(bytesArray[8:], uint32(specials.bos))
71-
binary.BigEndian.PutUint32(bytesArray[12:], uint32(specials.eos))
69+
binary.BigEndian.PutUint32(bytesArray, uint32(s.unk))
70+
binary.BigEndian.PutUint32(bytesArray[4:], uint32(s.pad))
71+
binary.BigEndian.PutUint32(bytesArray[8:], uint32(s.bos))
72+
binary.BigEndian.PutUint32(bytesArray[12:], uint32(s.eos))
7273
return bytesArray
7374
}
7475

75-
func binToSpecialTokens(bytesArray []byte) specialTokens {
76+
func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) {
7677
var s specialTokens
78+
if len(bytesArray) < 16 {
79+
logrus.Error("Bytes array length is too small")
80+
return s, errors.New("bytes array is too small")
81+
}
7782
s.unk = int32(binary.BigEndian.Uint32(bytesArray))
7883
s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:]))
7984
s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:]))
8085
s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:]))
81-
return s
86+
return s, nil
8287
}
8388

84-
func ruleToBin(rule rule) []byte {
89+
func (r rule) toBinary() []byte {
8590
bytesArray := make([]byte, 12)
86-
binary.BigEndian.PutUint32(bytesArray, uint32(rule.left))
87-
binary.BigEndian.PutUint32(bytesArray[4:], uint32(rule.right))
88-
binary.BigEndian.PutUint32(bytesArray[8:], uint32(rule.result))
91+
binary.BigEndian.PutUint32(bytesArray, uint32(r.left))
92+
binary.BigEndian.PutUint32(bytesArray[4:], uint32(r.right))
93+
binary.BigEndian.PutUint32(bytesArray[8:], uint32(r.result))
8994
return bytesArray
9095
}
9196

92-
func binToRule(bytesArray []byte) rule {
97+
func binaryToRule(bytesArray []byte) (rule, error) {
9398
var r rule
99+
if len(bytesArray) < 12 {
100+
logrus.Error("Bytes array length is too small")
101+
return r, errors.New("bytes array is too small")
102+
}
94103
r.left = TokenID(binary.BigEndian.Uint32(bytesArray))
95104
r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:]))
96105
r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:]))
97-
return r
106+
return r, nil
98107
}
99108

100-
// ReadModelFromBinary loads the BPE model from the binary dump
101-
func ReadModelFromBinary(reader io.Reader) (*Model, error) {
102-
bytesReader := bufio.NewReader(reader)
109+
// ReadModel loads the BPE model from the binary dump
110+
func ReadModel(reader io.Reader) (*Model, error) {
103111
buf := make([]byte, 4)
104112
var nChars, nRules int
105-
_, err := bytesReader.Read(buf)
106-
if err != nil {
107-
logrus.Fatal("Broken input: ", err)
113+
if _, err := io.ReadFull(reader, buf); err != nil {
114+
logrus.Error("Broken input: ", err)
108115
return &Model{}, err
109116
}
110117
nChars = int(binary.BigEndian.Uint32(buf))
111-
_, err = bytesReader.Read(buf)
112-
if err != nil {
113-
logrus.Fatal("Broken input: ", err)
118+
if _, err := io.ReadFull(reader, buf); err != nil {
119+
logrus.Error("Broken input: ", err)
114120
return &Model{}, err
115121
}
116122
nRules = int(binary.BigEndian.Uint32(buf))
@@ -119,15 +125,13 @@ func ReadModelFromBinary(reader io.Reader) (*Model, error) {
119125
for i := 0; i < nChars; i++ {
120126
var char rune
121127
var charID TokenID
122-
_, err = bytesReader.Read(buf)
123-
if err != nil {
124-
logrus.Fatal("Broken input: ", err)
128+
if _, err := io.ReadFull(reader, buf); err != nil {
129+
logrus.Error("Broken input: ", err)
125130
return &Model{}, err
126131
}
127132
char = rune(binary.BigEndian.Uint32(buf))
128-
_, err = bytesReader.Read(buf)
129-
if err != nil {
130-
logrus.Fatal("Broken input: ", err)
133+
if _, err := io.ReadFull(reader, buf); err != nil {
134+
logrus.Error("Broken input: ", err)
131135
return &Model{}, err
132136
}
133137
charID = TokenID(binary.BigEndian.Uint32(buf))
@@ -138,27 +142,37 @@ func ReadModelFromBinary(reader io.Reader) (*Model, error) {
138142
}
139143
ruleBuf := make([]byte, 12)
140144
for i := 0; i < nRules; i++ {
141-
_, err = bytesReader.Read(ruleBuf)
142-
if err != nil {
143-
logrus.Fatal("Broken input: ", err)
145+
if _, err := io.ReadFull(reader, ruleBuf); err != nil {
146+
logrus.Error("Broken input: ", err)
144147
return &Model{}, err
145148
}
146-
rule := binToRule(ruleBuf)
149+
rule, err := binaryToRule(ruleBuf)
150+
if err != nil {
151+
return model, err
152+
}
147153
model.rules[i] = rule
154+
if _, ok := model.recipe[rule.left]; !ok {
155+
logrus.Errorf("%d: token id not described before", rule.left)
156+
return model, errors.New("key not found in id2char")
157+
}
158+
if _, ok := model.recipe[rule.right]; !ok {
159+
logrus.Errorf("%d: token id not described before", rule.right)
160+
return model, errors.New("key not found in id2char")
161+
}
148162
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
149163
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
150164
if err != nil {
151-
logrus.Fatal("Unexpected token id inside the rules: ", err)
165+
logrus.Error("Unexpected token id inside the rules: ", err)
152166
return model, err
153167
}
154168
model.revRecipe[resultString] = rule.result
155169
}
156170
specialTokensBuf := make([]byte, 16)
157-
_, err = bytesReader.Read(specialTokensBuf)
158-
if err != nil {
159-
logrus.Fatal("Broken input: ", err)
171+
if _, err := io.ReadFull(reader, specialTokensBuf); err != nil {
172+
logrus.Error("Broken input: ", err)
160173
return &Model{}, err
161174
}
162-
model.specialTokens = binToSpecialTokens(specialTokensBuf)
163-
return model, nil
175+
specials, err := binaryToSpecialTokens(specialTokensBuf)
176+
model.specialTokens = specials
177+
return model, err
164178
}

bpe_test.go

Lines changed: 74 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -12,38 +12,54 @@ func TestNewModel(t *testing.T) {
1212
require.Equal(t, 10, len(model.rules))
1313
}
1414

15-
func TestDecodedTokenToString(t *testing.T) {
15+
func TestDecodeToken(t *testing.T) {
1616
id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]}
1717
word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char)
1818
require.NoError(t, err)
1919
require.Equal(t, "abacc", word)
2020
}
2121

22-
func TestSpecialTokensToBin(t *testing.T) {
22+
func TestSpecialTokensToBinary(t *testing.T) {
2323
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
2424
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
25-
require.Equal(t, bytesArray, specialTokensToBin(specials))
25+
require.Equal(t, bytesArray, specials.toBinary())
2626
}
2727

28-
func TestBinToSpecialTokens(t *testing.T) {
28+
func TestBinaryToSpecialTokens(t *testing.T) {
2929
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
30-
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
31-
require.Equal(t, specials, binToSpecialTokens(bytesArray))
30+
expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
31+
specials, err := binaryToSpecialTokens(bytesArray)
32+
require.NoError(t, err)
33+
require.Equal(t, expected, specials)
34+
bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0}
35+
specials, err = binaryToSpecialTokens(bytesArray)
36+
require.Error(t, err)
37+
bytesArray = []byte{}
38+
specials, err = binaryToSpecialTokens(bytesArray)
39+
require.Error(t, err)
3240
}
3341

34-
func TestRuleToBin(t *testing.T) {
42+
func TestRuleToBinary(t *testing.T) {
3543
rule := rule{1, 2, 257}
3644
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
37-
require.Equal(t, bytesArray, ruleToBin(rule))
45+
require.Equal(t, bytesArray, rule.toBinary())
3846
}
3947

40-
func TestBinToRule(t *testing.T) {
41-
rule := rule{1, 2, 257}
48+
func TestBinaryToRule(t *testing.T) {
49+
expected := rule{1, 2, 257}
4250
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
43-
require.Equal(t, rule, binToRule(bytesArray))
51+
rule, err := binaryToRule(bytesArray)
52+
require.NoError(t, err)
53+
require.Equal(t, expected, rule)
54+
bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1}
55+
rule, err = binaryToRule(bytesArray)
56+
require.Error(t, err)
57+
bytesArray = []byte{}
58+
rule, err = binaryToRule(bytesArray)
59+
require.Error(t, err)
4460
}
4561

46-
func TestReadModelFromBinary(t *testing.T) {
62+
func TestReadModel(t *testing.T) {
4763
reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
4864
0, 0, 0, 99, 0, 0, 0, 6,
4965
0, 0, 0, 98, 0, 0, 0, 7,
@@ -64,7 +80,52 @@ func TestReadModelFromBinary(t *testing.T) {
6480
"_a": 9, "_b": 12, "_c": 10, "_d": 11},
6581
specialTokens{1, 0, 2, 3},
6682
}
67-
model, err := ReadModelFromBinary(reader)
83+
model, err := ReadModel(reader)
6884
require.NoError(t, err)
6985
require.Equal(t, expected, *model)
86+
87+
reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
88+
0, 0, 0, 99, 0, 0, 0, 6,
89+
0, 0, 0, 98, 0, 0, 0, 7,
90+
0, 0, 0, 95, 0, 0, 0, 4,
91+
0, 0, 0, 100, 0, 0, 0, 5,
92+
0, 0, 0, 97, 0, 0, 0, 8,
93+
0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
94+
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
95+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
96+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
97+
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3,
98+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
99+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12})
100+
model, err = ReadModel(reader)
101+
require.NoError(t, err)
102+
require.Equal(t, expected, *model)
103+
104+
reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
105+
0, 0, 0, 99, 0, 0, 0, 6,
106+
0, 0, 0, 98, 0, 0, 0, 7,
107+
0, 0, 0, 95, 0, 0, 0, 4,
108+
0, 0, 0, 100, 0, 0, 0, 5,
109+
0, 0, 0, 97, 0, 0, 0, 8,
110+
0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
111+
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
112+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
113+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
114+
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0})
115+
model, err = ReadModel(reader)
116+
require.Error(t, err)
117+
118+
reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
119+
0, 0, 0, 99, 0, 0, 0, 6,
120+
0, 0, 0, 98, 0, 0, 0, 7,
121+
0, 0, 0, 95, 0, 0, 0, 4,
122+
0, 0, 0, 100, 0, 0, 0, 5,
123+
0, 0, 0, 97, 0, 0, 0, 8,
124+
0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9,
125+
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
126+
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
127+
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
128+
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
129+
model, err = ReadModel(reader)
130+
require.Error(t, err)
70131
}

0 commit comments

Comments
 (0)