Skip to content

Commit 6f9180c

Browse files
committed
Changed reinterpret_cast to static_cast
1 parent be8076a commit 6f9180c

2 files changed

Lines changed: 20 additions & 20 deletions

File tree

MLModelRunner/PTModelRunner/PTModelRunner.cpp

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -36,7 +36,7 @@ namespace MLBridge
3636
void *PTModelRunner::evaluateUntyped()
3737
{
3838

39-
if ((*reinterpret_cast<TensorVec*>(this->SerDes->getRequest())).empty())
39+
if ((*static_cast<TensorVec*>(this->SerDes->getRequest())).empty())
4040
{
4141
llvm::errs() << "Input vector is empty.\n";
4242
return nullptr;
@@ -45,8 +45,8 @@ namespace MLBridge
4545
try
4646
{
4747

48-
std::vector<torch::Tensor> *outputTensors = reinterpret_cast<std::vector<torch::Tensor>*>(this->SerDes->getResponse());
49-
auto outputs = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu*>(this->CompiledModel)->run((*reinterpret_cast<TensorVec*>(this->SerDes->getRequest())));
48+
std::vector<torch::Tensor> *outputTensors = static_cast<std::vector<torch::Tensor>*>(this->SerDes->getResponse());
49+
auto outputs = static_cast<torch::inductor::AOTIModelContainerRunnerCpu*>(this->CompiledModel)->run((*static_cast<TensorVec*>(this->SerDes->getRequest())));
5050
for (auto i = outputs.begin(); i != outputs.end(); ++i)
5151
(*(outputTensors)).push_back(*i);
5252
void *rawData = this->SerDes->deserializeUntyped(outputTensors);

SerDes/pytorchSerDes/pytorchSerDes.cpp

Lines changed: 17 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -24,54 +24,54 @@ PytorchSerDes::PytorchSerDes() : BaseSerDes(BaseSerDes::Kind::Pytorch) {
2424

2525
void PytorchSerDes::setFeature(const std::string &Name, const int Value) {
2626
auto tensor = torch::tensor({Value}, torch::kInt32);
27-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
27+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
2828
}
2929

3030
void PytorchSerDes::setFeature(const std::string &Name, const long Value) {
3131
auto tensor = torch::tensor({Value}, torch::kInt64);
32-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
32+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
3333
}
3434

3535
void PytorchSerDes::setFeature(const std::string &Name, const float Value) {
3636
auto tensor = torch::tensor({Value}, torch::kFloat32);
37-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
37+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
3838
}
3939

4040
void PytorchSerDes::setFeature(const std::string &Name, const double Value) {
4141
auto tensor = torch::tensor({Value}, torch::kFloat64);
42-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
42+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
4343
}
4444

4545
void PytorchSerDes::setFeature(const std::string &Name, const std::string Value) {
4646
std::vector<int8_t> encoded_str(Value.begin(), Value.end());
4747
auto tensor = torch::tensor(encoded_str, torch::kInt8);
48-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
48+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
4949
}
5050

5151
void PytorchSerDes::setFeature(const std::string &Name, const bool Value) {
5252
auto tensor = torch::tensor({Value}, torch::kBool);
53-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
53+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
5454
}
5555

5656
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<int> &Value) {
5757
auto tensor = torch::tensor(Value, torch::kInt32);
58-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
58+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
5959
}
6060

6161
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<long> &Value) {
6262
auto tensor = torch::tensor(Value, torch::kInt64);
63-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
63+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
6464
}
6565

6666
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<float> &Value) {
6767
auto tensor = torch::tensor(Value, torch::kFloat32);
6868
tensor = tensor.reshape({1, Value.size()});
69-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
69+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
7070
}
7171

7272
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<double> &Value) {
7373
auto tensor = torch::tensor(Value, torch::kFloat64);
74-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
74+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
7575
}
7676

7777
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::string> &Value) {
@@ -81,21 +81,21 @@ void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::s
8181
flat_vec.push_back('\0'); // Null-terminate each string
8282
}
8383
auto tensor = torch::tensor(flat_vec, torch::kInt8);
84-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
84+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
8585
}
8686

8787
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<bool> &Value) {
8888
std::vector<uint8_t> bool_vec(Value.begin(), Value.end());
8989
auto tensor = torch::tensor(bool_vec, torch::kUInt8);
90-
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
90+
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
9191
}
9292

9393
// void PytorchSerDes::setRequest(void *Request) {
94-
// CompiledModel = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
94+
// CompiledModel = static_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
9595
// }
9696

9797
void PytorchSerDes::cleanDataStructures() {
98-
reinterpret_cast<TensorVec*>(this->RequestVoid)->clear(); // Clear the input vector
98+
static_cast<TensorVec*>(this->RequestVoid)->clear(); // Clear the input vector
9999
}
100100

101101
void *PytorchSerDes::deserializeUntyped(void *Data) {
@@ -104,7 +104,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
104104
}
105105

106106
// Assume Data is a pointer to a vector of tensors
107-
std::vector<torch::Tensor> *serializedTensors = reinterpret_cast<TensorVec *>(Data);
107+
std::vector<torch::Tensor> *serializedTensors = static_cast<TensorVec *>(Data);
108108

109109
if (serializedTensors->empty()) {
110110
return nullptr;
@@ -138,7 +138,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
138138

139139
void *PytorchSerDes::getSerializedData() {
140140
return this->ResponseVoid; // TODO - check
141-
// TensorVec serializedData = *reinterpret_cast<TensorVec*>(this->ReponseVoid);
141+
// TensorVec serializedData = *static_cast<TensorVec*>(this->ReponseVoid);
142142

143143
// // Allocate memory for the output and copy the serialized data
144144
// auto *output = new TensorVec(serializedData);
@@ -148,7 +148,7 @@ void *PytorchSerDes::getSerializedData() {
148148
template <typename T>
149149
std::vector<T> *PytorchSerDes::copyTensorToVect(void *serializedTensors) {
150150
auto *ret = new std::vector<T>();
151-
for (const auto &tensor : *reinterpret_cast<TensorVec*>(serializedTensors)) {
151+
for (const auto &tensor : *static_cast<TensorVec*>(serializedTensors)) {
152152
ret->insert(ret->end(), tensor.data_ptr<T>(), tensor.data_ptr<T>() + tensor.numel());
153153
}
154154
return ret;

0 commit comments

Comments (0)