Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions bpe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package bpe

import (
"bufio"
"fmt"
"io"

"github.com/sirupsen/logrus"
)

// TokenId is the numeric identifier of a token. Both single characters and
// the results of merge rules are identified by a TokenId.
type TokenId int32
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated

// DecodedToken is a token expanded into an ordered sequence of token ids.
// ReadModel builds these from character ids only, which is what
// DecodedTokenToString expects when it looks each id up in id2char.
type DecodedToken []TokenId

type Rule struct {
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
left TokenId
right TokenId
result TokenId
}

type SpecialTokens struct {
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
unk TokenId
pad TokenId
bos TokenId
eos TokenId
}

// Model is the in-memory representation of a BPE model as filled in by
// ReadModel.
type Model struct {
// char2id maps a character to its token id.
char2id map[rune]TokenId
// id2char is the inverse of char2id: token id to character.
id2char map[TokenId]rune
// rules holds the merge rules in the order they were read.
rules []Rule
// recipe maps a token id to its expansion: a character id maps to a
// one-element sequence containing itself, a rule result maps to the
// recipe of its left token followed by the recipe of its right token.
recipe map[TokenId]DecodedToken
// revRecipe maps the string form of a token to its id.
revRecipe map[string]TokenId
// specialTokens holds the ids of the special tokens.
specialTokens SpecialTokens
}

func NewModel(nRules int) *Model {
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
return &Model{
make(map[rune]TokenId),
make(map[TokenId]rune),
make([]Rule, nRules),
make(map[TokenId]DecodedToken),
make(map[string]TokenId),
SpecialTokens{-1, -1, -1, -1},
}
}

// DecodedTokenToString converts a decoded token to text by looking up every
// id of token in id2char and concatenating the resulting characters in
// order. An empty token yields the empty string.
//
// If an id has no entry in id2char, an empty string and an error naming that
// id are returned; the function never terminates the process.
func DecodedTokenToString(token DecodedToken, id2char map[TokenId]rune) (string, error) {
	// Collect runes first and convert once, instead of re-allocating the
	// string on every iteration.
	runes := make([]rune, 0, len(token))
	for _, id := range token {
		char, ok := id2char[id]
		if !ok {
			return "", fmt.Errorf("%d key not found in id2char", id)
		}
		runes = append(runes, char)
	}
	return string(runes), nil
}

func ReadModel(reader io.Reader) (*Model, error) {
scanner := bufio.NewScanner(reader)
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
var nChars, nRules int
scanner.Scan()
_, err := fmt.Sscanf(scanner.Text(), "%d %d", &nChars, &nRules)
if err != nil {
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
logrus.Fatal("Wrong input format: ", err)
return &Model{}, err
}
model := NewModel(nRules)
model.rules = make([]Rule, nRules)
for i := 0; i < nChars; i++ {
var char rune
var charId TokenId
scanner.Scan()
_, err = fmt.Sscanf(scanner.Text(), "%d %d", &char, &charId)
if err != nil {
logrus.Fatal("Wrong input format: ", err)
return model, err
}
model.char2id[char] = charId
model.id2char[charId] = char
model.recipe[charId] = DecodedToken{charId}
model.revRecipe[string(char)] = charId
}
for i := 0; i < nRules; i++ {
var rule Rule
scanner.Scan()
_, err = fmt.Sscanf(scanner.Text(), "%d %d %d", &rule.left, &rule.right, &rule.result)
if err != nil {
logrus.Fatal("Wrong input format: ", err)
return model, err
}
model.rules[i] = rule
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
resultString, err := DecodedTokenToString(model.recipe[rule.result], model.id2char)
if err != nil {
logrus.Fatal("Unexpected token id inside the rules: ", err)
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
return model, err
}
model.revRecipe[resultString] = rule.result
}
scanner.Scan()
_, err = fmt.Sscanf(scanner.Text(), "%d %d %d %d", &model.specialTokens.unk,
&model.specialTokens.pad, &model.specialTokens.bos, &model.specialTokens.eos)
if err != nil {
logrus.Fatal("Wrong input format: ", err)
return model, err
}
return model, nil
}
44 changes: 44 additions & 0 deletions bpe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package bpe

import (
"github.com/stretchr/testify/require"
"strings"
"testing"
)

// TestNewModel verifies that NewModel sizes the rules slice to the requested
// number of rules.
func TestNewModel(t *testing.T) {
	const nRules = 10
	got := NewModel(nRules)
	require.Equal(t, nRules, len(got.rules))
}

// TestDecodedTokenToString verifies that ids are translated to characters in
// order, including repeated ids.
func TestDecodedTokenToString(t *testing.T) {
	alphabet := map[TokenId]rune{1: 'a', 2: 'b', 3: 'c'}
	token := DecodedToken{1, 2, 1, 3, 3}

	got, err := DecodedTokenToString(token, alphabet)

	require.NoError(t, err)
	require.Equal(t, "abacc", got)
}

func TestReadModel(t *testing.T) {
reader := strings.NewReader(`5 4
99 6
98 7
95 4
100 5
97 8
4 8 9
4 6 10
4 5 11
4 7 12
1 0 2 4`)
expected := Model{
map[rune]TokenId{97: 8, 98: 7, 99: 6, 100: 5, 95: 4},
map[TokenId]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97},
[]Rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}},
map[TokenId]DecodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}},
map[string]TokenId{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4,
"_a": 9, "_b": 12, "_c": 10, "_d": 11},
SpecialTokens{1, 0, 2, 4},
}
Comment thread
irinakhismatullina marked this conversation as resolved.
model, _ := ReadModel(reader)
require.Equal(t, expected, *model)
}
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module github.com/src-d/go-YouTokenToMe

go 1.12

require (
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
)
21 changes: 21 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
Comment thread
irinakhismatullina marked this conversation as resolved.
Outdated
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove go.sum

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I just add it to gitignore? It slips in when I forget to remove it

github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package bpe

import "fmt"

Expand Down