66
#include "SerDes/pytorchSerDes.h"

#include "SerDes/baseSerDes.h"

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#include <torch/torch.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

using TensorVec = std::vector<torch::Tensor>;
913
1014namespace MLBridge {
1115
16+ PytorchSerDes::PytorchSerDes () : BaseSerDes(BaseSerDes::Kind::Pytorch) {
17+ // inputTensors = std::make_shared<std::vector<torch::Tensor>>();
18+ // outputTensors = new std::vector<torch::Tensor>();
19+
20+ // RequestVoid = std::make_shared<std::vector<torch::Tensor>>();
21+ RequestVoid = new TensorVec ();
22+ ResponseVoid = new TensorVec ();
23+ }
24+
1225void PytorchSerDes::setFeature (const std::string &Name, const int Value) {
1326 auto tensor = torch::tensor ({Value}, torch::kInt32 );
14- this ->inputTensors ->push_back (tensor.clone ());
27+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
1528}
1629
1730void PytorchSerDes::setFeature (const std::string &Name, const long Value) {
1831 auto tensor = torch::tensor ({Value}, torch::kInt64 );
19- this ->inputTensors ->push_back (tensor.clone ());
32+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
2033}
2134
2235void PytorchSerDes::setFeature (const std::string &Name, const float Value) {
2336 auto tensor = torch::tensor ({Value}, torch::kFloat32 );
24- this ->inputTensors ->push_back (tensor.clone ());
37+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
2538}
2639
2740void PytorchSerDes::setFeature (const std::string &Name, const double Value) {
2841 auto tensor = torch::tensor ({Value}, torch::kFloat64 );
29- this ->inputTensors ->push_back (tensor.clone ());
42+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
3043}
3144
3245void PytorchSerDes::setFeature (const std::string &Name, const std::string Value) {
3346 std::vector<int8_t > encoded_str (Value.begin (), Value.end ());
3447 auto tensor = torch::tensor (encoded_str, torch::kInt8 );
35- this ->inputTensors ->push_back (tensor.clone ());
48+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
3649}
3750
3851void PytorchSerDes::setFeature (const std::string &Name, const bool Value) {
3952 auto tensor = torch::tensor ({Value}, torch::kBool );
40- this ->inputTensors ->push_back (tensor.clone ());
53+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
4154}
4255
4356void PytorchSerDes::setFeature (const std::string &Name, const std::vector<int > &Value) {
4457 auto tensor = torch::tensor (Value, torch::kInt32 );
45- this ->inputTensors ->push_back (tensor.clone ());
58+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
4659}
4760
4861void PytorchSerDes::setFeature (const std::string &Name, const std::vector<long > &Value) {
4962 auto tensor = torch::tensor (Value, torch::kInt64 );
50- this ->inputTensors ->push_back (tensor.clone ());
63+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
5164}
5265
5366void PytorchSerDes::setFeature (const std::string &Name, const std::vector<float > &Value) {
5467 auto tensor = torch::tensor (Value, torch::kFloat32 );
5568 tensor = tensor.reshape ({1 , Value.size ()});
56- inputTensors ->push_back (tensor.clone ());
69+ reinterpret_cast <TensorVec*>( this -> RequestVoid ) ->push_back (tensor.clone ());
5770}
5871
5972void PytorchSerDes::setFeature (const std::string &Name, const std::vector<double > &Value) {
6073 auto tensor = torch::tensor (Value, torch::kFloat64 );
61- this ->inputTensors ->push_back (tensor.clone ());
74+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
6275}
6376
6477void PytorchSerDes::setFeature (const std::string &Name, const std::vector<std::string> &Value) {
@@ -68,21 +81,21 @@ void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::s
6881 flat_vec.push_back (' \0 ' ); // Null-terminate each string
6982 }
7083 auto tensor = torch::tensor (flat_vec, torch::kInt8 );
71- this ->inputTensors ->push_back (tensor.clone ());
84+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
7285}
7386
7487void PytorchSerDes::setFeature (const std::string &Name, const std::vector<bool > &Value) {
7588 std::vector<uint8_t > bool_vec (Value.begin (), Value.end ());
7689 auto tensor = torch::tensor (bool_vec, torch::kUInt8 );
77- this ->inputTensors ->push_back (tensor.clone ());
90+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->push_back (tensor.clone ());
7891}
7992
8093// void PytorchSerDes::setRequest(void *Request) {
8194// CompiledModel = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
8295// }
8396
8497void PytorchSerDes::cleanDataStructures () {
85- this ->inputTensors ->clear (); // Clear the input vector
98+ reinterpret_cast <TensorVec*>( this ->RequestVoid ) ->clear (); // Clear the input vector
8699}
87100
88101void *PytorchSerDes::deserializeUntyped (void *Data) {
@@ -91,7 +104,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
91104 }
92105
93106 // Assume Data is a pointer to a vector of tensors
94- std::vector<torch::Tensor> *serializedTensors = reinterpret_cast <std::vector<torch::Tensor> *>(Data);
107+ std::vector<torch::Tensor> *serializedTensors = reinterpret_cast <TensorVec *>(Data);
95108
96109 if (serializedTensors->empty ()) {
97110 return nullptr ;
@@ -100,22 +113,22 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
100113 auto type_vect = serializedTensors->at (0 ).dtype ();
101114
102115 if (type_vect == torch::kInt32 ) {
103- return copyTensorToVect<int32_t >(serializedTensors );
116+ return copyTensorToVect<int32_t >(Data );
104117 }
105118 else if (type_vect == torch::kInt64 ) {
106- return copyTensorToVect<int64_t >(serializedTensors );
119+ return copyTensorToVect<int64_t >(Data );
107120 }
108121 else if (type_vect == torch::kFloat32 ) {
109- return copyTensorToVect<float >(serializedTensors );
122+ return copyTensorToVect<float >(Data );
110123 }
111124 else if (type_vect == torch::kFloat64 ) {
112- return copyTensorToVect<double >(serializedTensors );
125+ return copyTensorToVect<double >(Data );
113126 }
114127 else if (type_vect == torch::kBool ) {
115- return copyTensorToVect<bool >(serializedTensors );
128+ return copyTensorToVect<bool >(Data );
116129 }
117130 else if (type_vect == torch::kInt8 ) {
118- return copyTensorToVect<int8_t >(serializedTensors );
131+ return copyTensorToVect<int8_t >(Data );
119132 }
120133 else {
121134 llvm::errs () << " Unsupported tensor dtype.\n " ;
@@ -124,13 +137,25 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
124137}
125138
126139void *PytorchSerDes::getSerializedData () {
127- std::vector<torch::Tensor> serializedData = *(this ->outputTensors );
140+ return this ->ResponseVoid ; // TODO - check
141+ // TensorVec serializedData = *reinterpret_cast<TensorVec*>(this->ReponseVoid);
128142
129- // Allocate memory for the output and copy the serialized data
130- auto *output = new std::vector<torch::Tensor> (serializedData);
131- return static_cast <void *>(output);
143+ // // Allocate memory for the output and copy the serialized data
144+ // auto *output = new TensorVec (serializedData);
145+ // return static_cast<void *>(output);
132146}
133147
148+ template <typename T>
149+ std::vector<T> *PytorchSerDes::copyTensorToVect (void *serializedTensors) {
150+ auto *ret = new std::vector<T>();
151+ for (const auto &tensor : *reinterpret_cast <TensorVec*>(serializedTensors)) {
152+ ret->insert (ret->end (), tensor.data_ptr <T>(), tensor.data_ptr <T>() + tensor.numel ());
153+ }
154+ return ret;
155+ }
156+
157+ void *PytorchSerDes::getRequest () { return this ->RequestVoid ; }
158+ void *PytorchSerDes::getResponse () { return this ->ResponseVoid ; }
134159
135160
} // namespace MLBridge