Skip to content

Commit 87388a2

Browse files
authored
Merge pull request #2 from irinakhismatullina/read
Add function for reading the model dump
2 parents 76ebe6c + 5a79713 commit 87388a2

6 files changed

Lines changed: 316 additions & 8 deletions

File tree

.travis.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ install:
1616
script:
1717
- make install-dev-deps
1818
- make check-style
19-
- make test
19+
- make test-coverage
20+
- make codecov
2021

2122
matrix:
2223
fast_finish: true

bpe.go

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package bpe
2+
3+
import (
	"encoding/binary"
	"errors"
	"io"
	"strings"

	"github.com/sirupsen/logrus"
)
10+
11+
// TokenID is a numerical identifier of the subword token
type TokenID uint32

// EncodedToken is a sequence of subword tokens ids
type EncodedToken []TokenID

// rule is a single BPE merge: the pair (left, right) merges into result.
type rule struct {
	left   TokenID
	right  TokenID
	result TokenID
}

// specialTokens holds the ids of the reserved tokens
// (unknown, padding, begin-of-sentence, end-of-sentence).
// newModel initializes all four to -1.
type specialTokens struct {
	unk int32
	pad int32
	bos int32
	eos int32
}
29+
30+
// Model is a Byte-Pair encoding model, which supports encoding and decoding text into sequences
// of most frequent subword tokens
type Model struct {
	char2id       map[rune]TokenID         // id of the single-char token for each known rune
	id2char       map[TokenID]rune         // inverse mapping of char2id
	rules         []rule                   // merge rules in the order they appear in the dump
	recipe        map[TokenID]EncodedToken // token id -> ids of the single-char tokens it is built from
	revRecipe     map[string]TokenID       // decoded token string -> token id
	specialTokens specialTokens            // unk/pad/bos/eos ids; -1 when not set (see newModel)
}
40+
41+
func newModel(nRules int) *Model {
42+
return &Model{
43+
make(map[rune]TokenID),
44+
make(map[TokenID]rune),
45+
make([]rule, nRules),
46+
make(map[TokenID]EncodedToken),
47+
make(map[string]TokenID),
48+
specialTokens{-1, -1, -1, -1},
49+
}
50+
}
51+
52+
// DecodeToken converts the sequence of chars' ids into the string -
53+
// sequence of the corresponding chars
54+
func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) {
55+
word := ""
56+
for _, id := range token {
57+
if char, ok := id2char[id]; ok {
58+
word = word + string(char)
59+
} else {
60+
logrus.Errorf("Decode failure: %d token id has no corresponding char", id)
61+
return "", errors.New("key not found in id2char")
62+
}
63+
}
64+
return word, nil
65+
}
66+
67+
func (s specialTokens) toBinary() []byte {
68+
bytesArray := make([]byte, 16)
69+
binary.BigEndian.PutUint32(bytesArray, uint32(s.unk))
70+
binary.BigEndian.PutUint32(bytesArray[4:], uint32(s.pad))
71+
binary.BigEndian.PutUint32(bytesArray[8:], uint32(s.bos))
72+
binary.BigEndian.PutUint32(bytesArray[12:], uint32(s.eos))
73+
return bytesArray
74+
}
75+
76+
func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) {
77+
var s specialTokens
78+
if len(bytesArray) < 16 {
79+
logrus.Error("Bytes array length is too small")
80+
return s, errors.New("bytes array is too small")
81+
}
82+
s.unk = int32(binary.BigEndian.Uint32(bytesArray))
83+
s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:]))
84+
s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:]))
85+
s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:]))
86+
return s, nil
87+
}
88+
89+
func (r rule) toBinary() []byte {
90+
bytesArray := make([]byte, 12)
91+
binary.BigEndian.PutUint32(bytesArray, uint32(r.left))
92+
binary.BigEndian.PutUint32(bytesArray[4:], uint32(r.right))
93+
binary.BigEndian.PutUint32(bytesArray[8:], uint32(r.result))
94+
return bytesArray
95+
}
96+
97+
func binaryToRule(bytesArray []byte) (rule, error) {
98+
var r rule
99+
if len(bytesArray) < 12 {
100+
logrus.Error("Bytes array length is too small")
101+
return r, errors.New("bytes array is too small")
102+
}
103+
r.left = TokenID(binary.BigEndian.Uint32(bytesArray))
104+
r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:]))
105+
r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:]))
106+
return r, nil
107+
}
108+
109+
// ReadModel loads the BPE model from the binary dump
110+
func ReadModel(reader io.Reader) (*Model, error) {
111+
buf := make([]byte, 4)
112+
var nChars, nRules int
113+
if _, err := io.ReadFull(reader, buf); err != nil {
114+
logrus.Error("Broken input: ", err)
115+
return &Model{}, err
116+
}
117+
nChars = int(binary.BigEndian.Uint32(buf))
118+
if _, err := io.ReadFull(reader, buf); err != nil {
119+
logrus.Error("Broken input: ", err)
120+
return &Model{}, err
121+
}
122+
nRules = int(binary.BigEndian.Uint32(buf))
123+
124+
model := newModel(nRules)
125+
for i := 0; i < nChars; i++ {
126+
var char rune
127+
var charID TokenID
128+
if _, err := io.ReadFull(reader, buf); err != nil {
129+
logrus.Error("Broken input: ", err)
130+
return &Model{}, err
131+
}
132+
char = rune(binary.BigEndian.Uint32(buf))
133+
if _, err := io.ReadFull(reader, buf); err != nil {
134+
logrus.Error("Broken input: ", err)
135+
return &Model{}, err
136+
}
137+
charID = TokenID(binary.BigEndian.Uint32(buf))
138+
model.char2id[char] = charID
139+
model.id2char[charID] = char
140+
model.recipe[charID] = EncodedToken{charID}
141+
model.revRecipe[string(char)] = charID
142+
}
143+
ruleBuf := make([]byte, 12)
144+
for i := 0; i < nRules; i++ {
145+
if _, err := io.ReadFull(reader, ruleBuf); err != nil {
146+
logrus.Error("Broken input: ", err)
147+
return &Model{}, err
148+
}
149+
rule, err := binaryToRule(ruleBuf)
150+
if err != nil {
151+
return model, err
152+
}
153+
model.rules[i] = rule
154+
if _, ok := model.recipe[rule.left]; !ok {
155+
logrus.Errorf("%d: token id not described before", rule.left)
156+
return model, errors.New("key not found in id2char")
157+
}
158+
if _, ok := model.recipe[rule.right]; !ok {
159+
logrus.Errorf("%d: token id not described before", rule.right)
160+
return model, errors.New("key not found in id2char")
161+
}
162+
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
163+
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
164+
if err != nil {
165+
logrus.Error("Unexpected token id inside the rules: ", err)
166+
return model, err
167+
}
168+
model.revRecipe[resultString] = rule.result
169+
}
170+
specialTokensBuf := make([]byte, 16)
171+
if _, err := io.ReadFull(reader, specialTokensBuf); err != nil {
172+
logrus.Error("Broken input: ", err)
173+
return &Model{}, err
174+
}
175+
specials, err := binaryToSpecialTokens(specialTokensBuf)
176+
model.specialTokens = specials
177+
return model, err
178+
}

bpe_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package bpe
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestNewModel(t *testing.T) {
11+
model := newModel(10)
12+
require.Equal(t, 10, len(model.rules))
13+
}
14+
15+
func TestDecodeToken(t *testing.T) {
16+
id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]}
17+
word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char)
18+
require.NoError(t, err)
19+
require.Equal(t, "abacc", word)
20+
}
21+
22+
func TestSpecialTokensToBinary(t *testing.T) {
23+
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
24+
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
25+
require.Equal(t, bytesArray, specials.toBinary())
26+
}
27+
28+
func TestBinaryToSpecialTokens(t *testing.T) {
29+
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
30+
expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
31+
specials, err := binaryToSpecialTokens(bytesArray)
32+
require.NoError(t, err)
33+
require.Equal(t, expected, specials)
34+
bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0}
35+
specials, err = binaryToSpecialTokens(bytesArray)
36+
require.Error(t, err)
37+
bytesArray = []byte{}
38+
specials, err = binaryToSpecialTokens(bytesArray)
39+
require.Error(t, err)
40+
}
41+
42+
func TestRuleToBinary(t *testing.T) {
43+
rule := rule{1, 2, 257}
44+
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
45+
require.Equal(t, bytesArray, rule.toBinary())
46+
}
47+
48+
func TestBinaryToRule(t *testing.T) {
49+
expected := rule{1, 2, 257}
50+
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
51+
rule, err := binaryToRule(bytesArray)
52+
require.NoError(t, err)
53+
require.Equal(t, expected, rule)
54+
bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1}
55+
rule, err = binaryToRule(bytesArray)
56+
require.Error(t, err)
57+
bytesArray = []byte{}
58+
rule, err = binaryToRule(bytesArray)
59+
require.Error(t, err)
60+
}
61+
62+
// TestReadModel exercises ReadModel on well-formed and malformed dumps.
// Dump layout (all values big-endian uint32): nChars, nRules, then nChars
// (char code point, token id) pairs, then nRules (left, right, result)
// triples, then four special-token ids (unk, pad, bos, eos).
func TestReadModel(t *testing.T) {
	// Valid dump: 5 chars (c=6, b=7, _=4, d=5, a=8), 4 merge rules, specials.
	reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
	expected := Model{
		map[rune]TokenID{97: 8, 98: 7, 99: 6, 100: 5, 95: 4},
		map[TokenID]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97},
		[]rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}},
		map[TokenID]EncodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}},
		map[string]TokenID{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4,
			"_a": 9, "_b": 12, "_c": 10, "_d": 11},
		specialTokens{1, 0, 2, 3},
	}
	model, err := ReadModel(reader)
	require.NoError(t, err)
	require.Equal(t, expected, *model)

	// Same dump with trailing garbage after the special tokens: ReadModel
	// reads exactly what the header announces, so the extra bytes are ignored.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12})
	model, err = ReadModel(reader)
	require.NoError(t, err)
	require.Equal(t, expected, *model)

	// Truncated dump: the special-tokens section is one byte short.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0})
	model, err = ReadModel(reader)
	require.Error(t, err)

	// Invalid dump: the first rule references token id 20, which is never
	// defined in the chars section, so ReadModel must fail.
	reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
		0, 0, 0, 99, 0, 0, 0, 6,
		0, 0, 0, 98, 0, 0, 0, 7,
		0, 0, 0, 95, 0, 0, 0, 4,
		0, 0, 0, 100, 0, 0, 0, 5,
		0, 0, 0, 97, 0, 0, 0, 8,
		0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9,
		0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
		0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
		0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
		0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
	model, err = ReadModel(reader)
	require.Error(t, err)
}

go.mod

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
module github.com/src-d/go-YouTokenToMe
22

33
go 1.12
4+
5+
require (
6+
github.com/sirupsen/logrus v1.4.2
7+
github.com/stretchr/testify v1.4.0
8+
)

go.sum

Whitespace-only changes.

main.go

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)