Skip to content
This repository was archived by the owner on Apr 23, 2024. It is now read-only.

Commit 511ae6d

Browse files
committed
Change the model format to binary
We are using big endian numbers under the hood. Rule's x, y and z are written by plane, not interleaved. Signed-off-by: Vadim Markovtsev <vadim@sourced.tech>
1 parent f5f4bf3 commit 511ae6d

1 file changed

Lines changed: 90 additions & 19 deletions

File tree

youtokentome/cpp/utils.cpp

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,56 @@
22
#include <cassert>
33
#include <fstream>
44
#include <iostream>
5+
#include <memory>
56
#include <string>
67
#include <vector>
78

9+
810
namespace vkcom {
911
using std::string;
1012
using std::vector;
1113

14+
template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
15+
T bin_to_int(const char *val) {
16+
uint32_t ret = static_cast<unsigned char>(val[0]);
17+
ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[1])) << 8;
18+
ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[2])) << 16;
19+
ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[3])) << 24;
20+
return ret;
21+
}
22+
23+
template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
24+
std::unique_ptr<char[]> int_to_bin(T val) {
25+
auto u32 = static_cast<uint32_t>(val);
26+
std::unique_ptr<char[]> ret(new char[4]);
27+
ret[0] = u32 & 0xFF;
28+
ret[1] = (u32 >> 8) & 0xFF;
29+
ret[2] = (u32 >> 16) & 0xFF;
30+
ret[3] = (u32 >> 24); // no need for & 0xFF
31+
return std::move(ret);
32+
}
33+
1234
void SpecialTokens::dump(std::ofstream &fout) {
13-
fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
14-
<< std::endl;
35+
std::unique_ptr<char[]> unk_id_ptr(int_to_bin(unk_id)),
36+
pad_id_ptr(int_to_bin(pad_id)),
37+
bos_id_ptr(int_to_bin(bos_id)),
38+
eos_id_ptr(int_to_bin(eos_id));
39+
fout.write(unk_id_ptr.get(), 4);
40+
fout.write(pad_id_ptr.get(), 4);
41+
fout.write(bos_id_ptr.get(), 4);
42+
fout.write(eos_id_ptr.get(), 4);
1543
}
1644

1745
void SpecialTokens::load(std::ifstream &fin) {
18-
fin >> unk_id >> pad_id >> bos_id >> eos_id;
46+
char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4];
47+
fin.read(unk_id_bs, 4);
48+
fin.read(pad_id_bs, 4);
49+
fin.read(bos_id_bs, 4);
50+
fin.read(eos_id_bs, 4);
51+
this->unk_id = bin_to_int<int>(unk_id_bs);
52+
this->pad_id = bin_to_int<int>(pad_id_bs);
53+
this->bos_id = bin_to_int<int>(bos_id_bs);
54+
this->eos_id = bin_to_int<int>(eos_id_bs);
1955
}
2056

2157
uint32_t SpecialTokens::max_id() const {
@@ -50,18 +86,33 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {
5086
BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}
5187

5288
void BPEState::dump(const string &file_name) {
53-
std::ofstream fout(file_name, std::ios::out);
89+
std::ofstream fout(file_name, std::ios::out | std::ios::binary);
5490
if (fout.fail()) {
5591
std::cerr << "Can't open file: " << file_name << std::endl;
5692
assert(false);
5793
}
58-
fout << char2id.size() << " " << rules.size() << std::endl;
59-
for (auto s : char2id) {
60-
fout << s.first << " " << s.second << std::endl;
61-
}
6294

63-
for (auto rule : rules) {
64-
fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
95+
std::unique_ptr<char[]> char2id_ptr(int_to_bin(char2id.size())),
96+
rules_ptr(int_to_bin(rules.size()));
97+
fout.write(char2id_ptr.get(), 4);
98+
fout.write(rules_ptr.get(), 4);
99+
for (auto &s : char2id) {
100+
std::unique_ptr<char[]> first_ptr(int_to_bin(s.first)),
101+
second_ptr(int_to_bin(s.second));
102+
fout.write(first_ptr.get(), 4);
103+
fout.write(second_ptr.get(), 4);
104+
}
105+
for (auto &rule : rules) {
106+
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.x));
107+
fout.write(rule_ptr.get(), 4);
108+
}
109+
for (auto &rule : rules) {
110+
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.y));
111+
fout.write(rule_ptr.get(), 4);
112+
}
113+
for (auto &rule : rules) {
114+
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.z));
115+
fout.write(rule_ptr.get(), 4);
65116
}
66117
special_tokens.dump(fout);
67118
fout.close();
@@ -70,24 +121,44 @@ void BPEState::dump(const string &file_name) {
70121
void BPEState::load(const string &file_name) {
71122
char2id.clear();
72123
rules.clear();
73-
std::ifstream fin(file_name, std::ios::in);
124+
std::ifstream fin(file_name, std::ios::in | std::ios::binary);
74125
if (fin.fail()) {
75126
std::cerr << "Error. Can not open file with model: " << file_name
76127
<< std::endl;
77128
exit(EXIT_FAILURE);
78129
}
79-
int n, m;
80-
fin >> n >> m;
130+
char n_bs[4], m_bs[4];
131+
fin.read(n_bs, 4);
132+
fin.read(m_bs, 4);
133+
auto n = bin_to_int<int>(n_bs);
134+
auto m = bin_to_int<int>(m_bs);
81135
for (int i = 0; i < n; i++) {
82-
uint32_t inner_id;
83-
uint32_t utf32_id;
84-
fin >> inner_id >> utf32_id;
136+
char inner_id_bs[4], utf32_id_bs[4];
137+
fin.read(inner_id_bs, 4);
138+
fin.read(utf32_id_bs, 4);
139+
auto inner_id = bin_to_int<uint32_t>(inner_id_bs);
140+
auto utf32_id = bin_to_int<uint32_t>(utf32_id_bs);
85141
char2id[inner_id] = utf32_id;
86142
}
143+
std::vector<std::tuple<uint32_t, uint32_t, uint32_t>> rules_xyz(m);
144+
for (int j = 0; j < 3; j++) {
145+
for (int i = 0; i < m; i++) {
146+
char val[4];
147+
fin.read(val, 4);
148+
uint32_t *element;
149+
switch (j) {
150+
case 0:
151+
element = &std::get<0>(rules_xyz[i]);
152+
case 1:
153+
element = &std::get<1>(rules_xyz[i]);
154+
case 2:
155+
element = &std::get<2>(rules_xyz[i]);
156+
}
157+
*element = bin_to_int<uint32_t>(val);
158+
}
159+
}
87160
for (int i = 0; i < m; i++) {
88-
uint32_t x, y, z;
89-
fin >> x >> y >> z;
90-
rules.emplace_back(x, y, z);
161+
rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]), std::get<2>(rules_xyz[i]));
91162
}
92163
special_tokens.load(fin);
93164
fin.close();

0 commit comments

Comments
 (0)