Skip to content

Commit be8076a

Browse files
committed
Removing torch includes from pytorch runner/serdes headers
1 parent e6ae5a1 commit be8076a

8 files changed

Lines changed: 257 additions & 116 deletions

File tree

MLModelRunner/CMakeLists.txt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ if(NOT PROTOS_DIRECTORY STREQUAL "")
44
add_subdirectory(gRPCModelRunner)
55
endif()
66
add_subdirectory(ONNXModelRunner)
7+
add_subdirectory(PTModelRunner)
78
add_subdirectory(C)
89

910
# # For up-to-date instructions for installing the TFLite dependency, refer to
@@ -44,13 +45,8 @@ else()
4445
add_library(ModelRunnerLib OBJECT MLModelRunner.cpp PipeModelRunner.cpp)
4546
endif(LLVM_MLBRIDGE)
4647

47-
target_link_libraries(ModelRunnerLib PUBLIC ModelRunnerUtils ONNXModelRunnerLib)
48+
target_link_libraries(ModelRunnerLib PUBLIC ModelRunnerUtils ONNXModelRunnerLib PTModelRunnerLib)
4849

4950
if(NOT PROTOS_DIRECTORY STREQUAL "")
5051
target_link_libraries(ModelRunnerLib PUBLIC gRPCModelRunnerLib)
51-
endif()
52-
set_property(TARGET ModelRunnerLib PROPERTY POSITION_INDEPENDENT_CODE 1)
53-
54-
find_package(Torch REQUIRED)
55-
target_link_libraries(ModelRunnerLib PRIVATE ${TORCH_LIBRARIES})
56-
target_compile_options (ModelRunnerLib PRIVATE -fexceptions)
52+
endif()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Build PTModelRunnerLib, the PyTorch (AOTInductor) MLModelRunner backend.
# Torch is linked/included PRIVATE-ly so that torch headers do not leak into
# dependents of this target (the point of this refactor).
if(LLVM_MLBRIDGE)
  add_llvm_library(PTModelRunnerLib PTModelRunner.cpp)
else()
  add_library(PTModelRunnerLib OBJECT PTModelRunner.cpp)
endif(LLVM_MLBRIDGE)

find_package(Torch REQUIRED)
target_link_libraries(PTModelRunnerLib PRIVATE ${TORCH_LIBRARIES})
# libtorch APIs throw; this target must be built with exceptions enabled even
# if the surrounding LLVM build disables them.
target_compile_options(PTModelRunnerLib PRIVATE -fexceptions)
target_include_directories(PTModelRunnerLib PRIVATE ${TORCH_INCLUDE_DIRS})
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//===-- PTModelRunner.cpp - PTModelRunner Implementation ---*- C++ -*------===//
//
// Part of the MLCompilerBridge Project
//
//===----------------------------------------------------------------------===//
//
// MLModelRunner backed by a PyTorch AOTInductor-compiled model. Features are
// serialized through PytorchSerDes into a std::vector<torch::Tensor>, the
// compiled model is run on them, and the outputs are handed back to the
// SerDes to be flattened into a raw buffer for the caller.
//
//===----------------------------------------------------------------------===//

#include "MLModelRunner/PTModelRunner.h"

#include "MLModelRunner/MLModelRunner.h"
#include "SerDes/TensorSpec.h"
#include "SerDes/pytorchSerDes.h"
#include "llvm/Support/ErrorHandling.h"

#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h> // or model_container_runner_cuda.h for CUDA

#include <memory>
#include <vector>

using TensorVec = std::vector<torch::Tensor>;

namespace MLBridge {

/// Loads the AOTInductor-compiled model at \p modelPath and installs a
/// PytorchSerDes for feature (de)serialization.
///
/// NOTE(review): SerDes and CompiledModel are heap-allocated here and never
/// freed in this file -- confirm the base class owns them or add a destructor.
PTModelRunner::PTModelRunner(const std::string &modelPath,
                             llvm::LLVMContext &Ctx)
    : MLModelRunner(MLModelRunner::Kind::PTAOT, BaseSerDes::Kind::Pytorch,
                    &Ctx) {
  this->SerDes = new PytorchSerDes();

  // NOTE(review): this InferenceMode guard is scoped to the constructor only;
  // it does not cover later evaluateUntyped() calls -- confirm intent.
  c10::InferenceMode mode;
  this->CompiledModel =
      new torch::inductor::AOTIModelContainerRunnerCpu(modelPath);
}

/// Runs the compiled model on the tensors accumulated in the SerDes request
/// buffer, appends the outputs to the SerDes response buffer, and returns the
/// deserialized raw buffer. Returns nullptr on empty input or torch error.
void *PTModelRunner::evaluateUntyped() {
  auto *inputTensors =
      reinterpret_cast<TensorVec *>(this->SerDes->getRequest());
  if (inputTensors->empty()) {
    llvm::errs() << "Input vector is empty.\n";
    return nullptr;
  }

  try {
    auto *outputTensors =
        reinterpret_cast<TensorVec *>(this->SerDes->getResponse());
    auto *runner =
        reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(
            this->CompiledModel);
    TensorVec outputs = runner->run(*inputTensors);
    for (const auto &tensor : outputs)
      outputTensors->push_back(tensor);
    // The SerDes flattens the tensors into a raw, caller-consumable buffer.
    return this->SerDes->deserializeUntyped(outputTensors);
  } catch (const c10::Error &e) {
    llvm::errs() << "Error during model evaluation: " << e.what() << "\n";
    return nullptr;
  }
}

/// Serializes each (name, value) feature pair into the request buffer, one
/// pair per recursion step. Relies on a zero-argument populateFeatures()
/// overload (declared in the header) as the recursion base case.
template <typename U, typename T, typename... Types>
void PTModelRunner::populateFeatures(const std::pair<U, T> &var1,
                                     const std::pair<U, Types> &...var2) {
  SerDes->setFeature(var1.first, var1.second);
  PTModelRunner::populateFeatures(var2...);
}

} // namespace MLBridge

SerDes/CMakeLists.txt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
find_package(Torch REQUIRED)
22

3+
add_subdirectory(pytorchSerDes)
4+
35
set(protobuf_MODULE_COMPATIBLE TRUE)
46
find_package(Protobuf CONFIG REQUIRED)
57

@@ -9,18 +11,15 @@ TensorSpec.cpp
911
jsonSerDes.cpp
1012
bitstreamSerDes.cpp
1113
protobufSerDes.cpp
12-
pytorchSerDes.cpp
1314
tensorflowSerDes.cpp
1415
JSON.cpp
1516
)
1617

1718
else()
18-
add_library(SerDesLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp protobufSerDes.cpp tensorflowSerDes.cpp JSON.cpp pytorchSerDes.cpp)
19-
20-
# add_library(SerDesCLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp JSON.cpp)
19+
add_library(SerDesLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp protobufSerDes.cpp tensorflowSerDes.cpp JSON.cpp )
2120
endif()
2221

23-
target_link_libraries(SerDesLib PRIVATE ${TORCH_LIBRARIES})
24-
target_compile_options(SerDesLib PRIVATE -fexceptions)
22+
target_link_libraries(SerDesLib PUBLIC PyTorchSerDesLib)
23+
2524
target_include_directories(SerDesLib PUBLIC ${TENSORFLOW_AOT_PATH}/include)
2625
target_link_libraries(SerDesLib PRIVATE tf_xla_runtime)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# Build PyTorchSerDesLib, the PyTorch tensor (de)serializer, as a standalone
# target so that torch headers/libs stay PRIVATE and do not leak into SerDesLib.
find_package(Torch REQUIRED)

if(LLVM_MLBRIDGE)
  add_llvm_library(PyTorchSerDesLib
    pytorchSerDes.cpp
  )
else()
  add_library(PyTorchSerDesLib OBJECT pytorchSerDes.cpp)
endif()

target_link_libraries(PyTorchSerDesLib PRIVATE ${TORCH_LIBRARIES})
# libtorch APIs throw; build this target with exceptions enabled.
target_compile_options(PyTorchSerDesLib PRIVATE -fexceptions)
target_include_directories(PyTorchSerDesLib PUBLIC ${TORCH_INCLUDE_DIRS})
Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,59 +6,72 @@
66

77
#include "SerDes/pytorchSerDes.h"
88
#include "SerDes/baseSerDes.h"
9+
#include <torch/torch.h>
10+
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
11+
12+
using TensorVec = std::vector<torch::Tensor>;
913

1014
namespace MLBridge {
1115

16+
/// Allocates the type-erased request/response buffers (each a
/// std::vector<torch::Tensor> behind a void*), keeping torch types out of the
/// header.
/// NOTE(review): these buffers are `new`'d and not freed in this file --
/// confirm they are released in the destructor or base class.
PytorchSerDes::PytorchSerDes() : BaseSerDes(BaseSerDes::Kind::Pytorch) {
  RequestVoid = new TensorVec();
  ResponseVoid = new TensorVec();
}
1225
void PytorchSerDes::setFeature(const std::string &Name, const int Value) {
1326
auto tensor = torch::tensor({Value}, torch::kInt32);
14-
this->inputTensors->push_back(tensor.clone());
27+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
1528
}
1629

1730
void PytorchSerDes::setFeature(const std::string &Name, const long Value) {
1831
auto tensor = torch::tensor({Value}, torch::kInt64);
19-
this->inputTensors->push_back(tensor.clone());
32+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
2033
}
2134

2235
void PytorchSerDes::setFeature(const std::string &Name, const float Value) {
2336
auto tensor = torch::tensor({Value}, torch::kFloat32);
24-
this->inputTensors->push_back(tensor.clone());
37+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
2538
}
2639

2740
void PytorchSerDes::setFeature(const std::string &Name, const double Value) {
2841
auto tensor = torch::tensor({Value}, torch::kFloat64);
29-
this->inputTensors->push_back(tensor.clone());
42+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
3043
}
3144

3245
void PytorchSerDes::setFeature(const std::string &Name, const std::string Value) {
3346
std::vector<int8_t> encoded_str(Value.begin(), Value.end());
3447
auto tensor = torch::tensor(encoded_str, torch::kInt8);
35-
this->inputTensors->push_back(tensor.clone());
48+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
3649
}
3750

3851
void PytorchSerDes::setFeature(const std::string &Name, const bool Value) {
3952
auto tensor = torch::tensor({Value}, torch::kBool);
40-
this->inputTensors->push_back(tensor.clone());
53+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
4154
}
4255

4356
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<int> &Value) {
4457
auto tensor = torch::tensor(Value, torch::kInt32);
45-
this->inputTensors->push_back(tensor.clone());
58+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
4659
}
4760

4861
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<long> &Value) {
4962
auto tensor = torch::tensor(Value, torch::kInt64);
50-
this->inputTensors->push_back(tensor.clone());
63+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
5164
}
5265

5366
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<float> &Value) {
5467
auto tensor = torch::tensor(Value, torch::kFloat32);
5568
tensor = tensor.reshape({1, Value.size()});
56-
inputTensors->push_back(tensor.clone());
69+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
5770
}
5871

5972
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<double> &Value) {
6073
auto tensor = torch::tensor(Value, torch::kFloat64);
61-
this->inputTensors->push_back(tensor.clone());
74+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
6275
}
6376

6477
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::string> &Value) {
@@ -68,21 +81,21 @@ void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::s
6881
flat_vec.push_back('\0'); // Null-terminate each string
6982
}
7083
auto tensor = torch::tensor(flat_vec, torch::kInt8);
71-
this->inputTensors->push_back(tensor.clone());
84+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
7285
}
7386

7487
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<bool> &Value) {
7588
std::vector<uint8_t> bool_vec(Value.begin(), Value.end());
7689
auto tensor = torch::tensor(bool_vec, torch::kUInt8);
77-
this->inputTensors->push_back(tensor.clone());
90+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
7891
}
7992

8093
// void PytorchSerDes::setRequest(void *Request) {
8194
// CompiledModel = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
8295
// }
8396

8497
// Empties the request buffer so the serializer can be reused.
// NOTE(review): the response buffer (ResponseVoid) is not cleared here --
// confirm outputs are not meant to accumulate across evaluations.
void PytorchSerDes::cleanDataStructures() {
  auto *request = reinterpret_cast<TensorVec *>(this->RequestVoid);
  request->clear(); // Clear the input vector
}
87100

88101
void *PytorchSerDes::deserializeUntyped(void *Data) {
@@ -91,7 +104,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
91104
}
92105

93106
// Assume Data is a pointer to a vector of tensors
94-
std::vector<torch::Tensor> *serializedTensors = reinterpret_cast<std::vector<torch::Tensor> *>(Data);
107+
std::vector<torch::Tensor> *serializedTensors = reinterpret_cast<TensorVec *>(Data);
95108

96109
if (serializedTensors->empty()) {
97110
return nullptr;
@@ -100,22 +113,22 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
100113
auto type_vect = serializedTensors->at(0).dtype();
101114

102115
if (type_vect == torch::kInt32) {
103-
return copyTensorToVect<int32_t>(serializedTensors);
116+
return copyTensorToVect<int32_t>(Data);
104117
}
105118
else if (type_vect == torch::kInt64) {
106-
return copyTensorToVect<int64_t>(serializedTensors);
119+
return copyTensorToVect<int64_t>(Data);
107120
}
108121
else if (type_vect == torch::kFloat32) {
109-
return copyTensorToVect<float>(serializedTensors);
122+
return copyTensorToVect<float>(Data);
110123
}
111124
else if (type_vect == torch::kFloat64) {
112-
return copyTensorToVect<double>(serializedTensors);
125+
return copyTensorToVect<double>(Data);
113126
}
114127
else if (type_vect == torch::kBool) {
115-
return copyTensorToVect<bool>(serializedTensors);
128+
return copyTensorToVect<bool>(Data);
116129
}
117130
else if (type_vect == torch::kInt8) {
118-
return copyTensorToVect<int8_t>(serializedTensors);
131+
return copyTensorToVect<int8_t>(Data);
119132
}
120133
else {
121134
llvm::errs() << "Unsupported tensor dtype.\n";
@@ -124,13 +137,25 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
124137
}
125138

126139
/// Returns the internal response buffer (a std::vector<torch::Tensor>*)
/// without copying.
/// NOTE(review): previously this returned a heap-allocated copy that the
/// caller owned; the returned pointer now aliases ResponseVoid and must NOT
/// be deleted by callers -- confirm all call sites.
void *PytorchSerDes::getSerializedData() {
  return this->ResponseVoid;
}
133147

148+
template <typename T>
149+
std::vector<T> *PytorchSerDes::copyTensorToVect(void *serializedTensors) {
150+
auto *ret = new std::vector<T>();
151+
for (const auto &tensor : *reinterpret_cast<TensorVec*>(serializedTensors)) {
152+
ret->insert(ret->end(), tensor.data_ptr<T>(), tensor.data_ptr<T>() + tensor.numel());
153+
}
154+
return ret;
155+
}
156+
157+
void *PytorchSerDes::getRequest() { return this->RequestVoid; }
158+
void *PytorchSerDes::getResponse() { return this->ResponseVoid; }
134159

135160

136-
} // namespace MLBridge
161+
} // namespace MLBridge

0 commit comments

Comments
 (0)