#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
78
9+
810namespace vkcom {
911using std::string;
1012using std::vector;
1113
/**
 * Decodes a 32-bit value stored least-significant byte first.
 *
 * @tparam T   integral result type; the assembled uint32_t is converted to T.
 * @param val  pointer to at least 4 readable bytes.
 * @return     the decoded value, converted to T.
 */
template <typename T,
          typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
T bin_to_int(const char *val) {
  uint32_t ret = 0;
  for (int i = 0; i < 4; ++i) {
    // Go through unsigned char so bytes >= 0x80 are not sign-extended.
    ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[i])) << (8 * i);
  }
  return static_cast<T>(ret);
}
22+
/**
 * Encodes the low 32 bits of an integer as 4 bytes, least-significant byte
 * first.
 *
 * @tparam T   integral input type; the value is converted to uint32_t, so
 *             wider inputs are truncated to their low 32 bits.
 * @param val  value to encode.
 * @return     owning pointer to a freshly allocated 4-byte buffer.
 */
template <typename T,
          typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
std::unique_ptr<char[]> int_to_bin(T val) {
  const auto u32 = static_cast<uint32_t>(val);
  std::unique_ptr<char[]> ret(new char[4]);
  for (int i = 0; i < 4; ++i) {
    ret[i] = static_cast<char>((u32 >> (8 * i)) & 0xFF);
  }
  // Return the local by name: wrapping it in std::move would block NRVO.
  return ret;
}
33+
1234void SpecialTokens::dump (std::ofstream &fout) {
13- fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
14- << std::endl;
35+ std::unique_ptr<char []> unk_id_ptr (int_to_bin (unk_id)),
36+ pad_id_ptr (int_to_bin (pad_id)),
37+ bos_id_ptr (int_to_bin (bos_id)),
38+ eos_id_ptr (int_to_bin (eos_id));
39+ fout.write (unk_id_ptr.get (), 4 );
40+ fout.write (pad_id_ptr.get (), 4 );
41+ fout.write (bos_id_ptr.get (), 4 );
42+ fout.write (eos_id_ptr.get (), 4 );
1543}
1644
1745void SpecialTokens::load (std::ifstream &fin) {
18- fin >> unk_id >> pad_id >> bos_id >> eos_id;
46+ char unk_id_bs[4 ], pad_id_bs[4 ], bos_id_bs[4 ], eos_id_bs[4 ];
47+ fin.read (unk_id_bs, 4 );
48+ fin.read (pad_id_bs, 4 );
49+ fin.read (bos_id_bs, 4 );
50+ fin.read (eos_id_bs, 4 );
51+ this ->unk_id = bin_to_int<int >(unk_id_bs);
52+ this ->pad_id = bin_to_int<int >(pad_id_bs);
53+ this ->bos_id = bin_to_int<int >(bos_id_bs);
54+ this ->eos_id = bin_to_int<int >(eos_id_bs);
1955}
2056
2157uint32_t SpecialTokens::max_id () const {
@@ -50,18 +86,33 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {
5086BPE_Rule::BPE_Rule (uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}
5187
5288void BPEState::dump (const string &file_name) {
53- std::ofstream fout (file_name, std::ios::out);
89+ std::ofstream fout (file_name, std::ios::out | std::ios::binary );
5490 if (fout.fail ()) {
5591 std::cerr << " Can't open file: " << file_name << std::endl;
5692 assert (false );
5793 }
58- fout << char2id.size () << " " << rules.size () << std::endl;
59- for (auto s : char2id) {
60- fout << s.first << " " << s.second << std::endl;
61- }
6294
63- for (auto rule : rules) {
64- fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
95+ std::unique_ptr<char []> char2id_ptr (int_to_bin (char2id.size ())),
96+ rules_ptr (int_to_bin (rules.size ()));
97+ fout.write (char2id_ptr.get (), 4 );
98+ fout.write (rules_ptr.get (), 4 );
99+ for (auto &s : char2id) {
100+ std::unique_ptr<char []> first_ptr (int_to_bin (s.first )),
101+ second_ptr (int_to_bin (s.second ));
102+ fout.write (first_ptr.get (), 4 );
103+ fout.write (second_ptr.get (), 4 );
104+ }
105+ for (auto &rule : rules) {
106+ std::unique_ptr<char []> rule_ptr (int_to_bin (rule.x ));
107+ fout.write (rule_ptr.get (), 4 );
108+ }
109+ for (auto &rule : rules) {
110+ std::unique_ptr<char []> rule_ptr (int_to_bin (rule.y ));
111+ fout.write (rule_ptr.get (), 4 );
112+ }
113+ for (auto &rule : rules) {
114+ std::unique_ptr<char []> rule_ptr (int_to_bin (rule.z ));
115+ fout.write (rule_ptr.get (), 4 );
65116 }
66117 special_tokens.dump (fout);
67118 fout.close ();
@@ -70,24 +121,44 @@ void BPEState::dump(const string &file_name) {
70121void BPEState::load (const string &file_name) {
71122 char2id.clear ();
72123 rules.clear ();
73- std::ifstream fin (file_name, std::ios::in);
124+ std::ifstream fin (file_name, std::ios::in | std::ios::binary );
74125 if (fin.fail ()) {
75126 std::cerr << " Error. Can not open file with model: " << file_name
76127 << std::endl;
77128 exit (EXIT_FAILURE);
78129 }
79- int n, m;
80- fin >> n >> m;
130+ char n_bs[4 ], m_bs[4 ];
131+ fin.read (n_bs, 4 );
132+ fin.read (m_bs, 4 );
133+ auto n = bin_to_int<int >(n_bs);
134+ auto m = bin_to_int<int >(m_bs);
81135 for (int i = 0 ; i < n; i++) {
82- uint32_t inner_id;
83- uint32_t utf32_id;
84- fin >> inner_id >> utf32_id;
136+ char inner_id_bs[4 ], utf32_id_bs[4 ];
137+ fin.read (inner_id_bs, 4 );
138+ fin.read (utf32_id_bs, 4 );
139+ auto inner_id = bin_to_int<uint32_t >(inner_id_bs);
140+ auto utf32_id = bin_to_int<uint32_t >(utf32_id_bs);
85141 char2id[inner_id] = utf32_id;
86142 }
143+ std::vector<std::tuple<uint32_t , uint32_t , uint32_t >> rules_xyz (m);
144+ for (int j = 0 ; j < 3 ; j++) {
145+ for (int i = 0 ; i < m; i++) {
146+ char val[4 ];
147+ fin.read (val, 4 );
148+ uint32_t *element;
149+ switch (j) {
150+ case 0 :
151+ element = &std::get<0 >(rules_xyz[i]);
152+ case 1 :
153+ element = &std::get<1 >(rules_xyz[i]);
154+ case 2 :
155+ element = &std::get<2 >(rules_xyz[i]);
156+ }
157+ *element = bin_to_int<uint32_t >(val);
158+ }
159+ }
87160 for (int i = 0 ; i < m; i++) {
88- uint32_t x, y, z;
89- fin >> x >> y >> z;
90- rules.emplace_back (x, y, z);
161+ rules.emplace_back (std::get<0 >(rules_xyz[i]), std::get<1 >(rules_xyz[i]), std::get<2 >(rules_xyz[i]));
91162 }
92163 special_tokens.load (fin);
93164 fin.close ();
0 commit comments