forked from LearningInfiniTensor/TinyInfiniTensor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconcat.cc
More file actions
53 lines (49 loc) · 1.69 KB
/
concat.cc
File metadata and controls
53 lines (49 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include "operators/concat.h"
#include "utils/operator_utils.h"
namespace infini {
/// Builds a Concat operator that joins `inputs` along axis `_dim` into
/// `output`.
///
/// @param graph  owning graph, forwarded to checkValid().
/// @param inputs tensors to concatenate; must be non-empty.
/// @param output result tensor.
/// @param _dim   concat axis. It is normalized by get_real_axis() against the
///               rank of the first input (presumably so negative, ONNX-style
///               axes are accepted — confirm against get_real_axis).
ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
    : OperatorObj(OpType::Concat, inputs, {output}) {
    // inputs[0] is read below; an empty list would be undefined behaviour.
    IT_ASSERT(!inputs.empty());
    int rank = inputs[0]->getRank();
    dim = get_real_axis(_dim, rank);
    IT_ASSERT(checkValid(graph));
}
/// Infers the output shape of Concat.
/// Reference: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
///
/// Rules for the inputs:
///   - every input must have the same rank;
///   - every input must have the same size on every axis except `dim`.
/// The output size on `dim` is the sum of the inputs' sizes on that axis.
///
/// @param inputs tensors to concatenate, in order.
/// @return a single output shape, or std::nullopt in any of these cases:
///         `inputs` is empty; the ranks differ; `dim` does not index into
///         the rank; or a non-concat axis disagrees between inputs.
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
    // Check emptiness BEFORE touching inputs[0]; reading inputs[0] on an
    // empty vector is undefined behaviour.
    if (inputs.empty()) {
        return std::nullopt;
    }
    // Start from the first input's shape; only the concat axis will grow.
    Shape dims = inputs[0]->getDims();
    const auto rank = inputs[0]->getRank();
    const auto axis = static_cast<size_t>(dim);
    // Defensive: if `dim` does not index into this rank, the concat axis
    // would otherwise be silently skipped.
    if (axis >= rank) {
        return std::nullopt;
    }
    for (size_t k = 1; k < inputs.size(); ++k) {
        // Fetch each input's shape once instead of once per axis.
        const Shape &other = inputs[k]->getDims();
        if (other.size() != rank) {
            return std::nullopt;
        }
        for (size_t i = 0; i < rank; ++i) {
            if (i == axis) {
                dims[i] += other[i];
            } else if (dims[i] != other[i]) {
                // ONNX Concat requires all non-concat axes to match exactly.
                return std::nullopt;
            }
        }
    }
    return {{dims}};
}
std::string ConcatObj::toString() const {
std::ostringstream os;
os << "Concat[" << getGuid() << "]";
os << "(";
for (auto input : inputs)
os << vecToString(input->getDims()) << ",";
os << "dim=" << dim << ",";
os << "input=";
for (auto input : inputs)
os << input->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
} // namespace infini