We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1d113db commit 03f067eCopy full SHA for 03f067e
1 file changed
src/operators/concat.cc
@@ -17,6 +17,17 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
17
// TODO:修改 dims,返回正确的 concat 后的 shape
18
// REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
19
// =================================== 作业 ===================================
20
+ size_t sum = 0;
21
+ for (const auto &t : inputs) {
22
+ IT_ASSERT(t->getRank() == rank);
23
+ const auto &d = t->getDims();
24
+ for (size_t i = 0; i < rank; ++i) {
25
+ if (i != static_cast<size_t>(dim))
26
+ IT_ASSERT(d[i] == dims[i]);
27
+ }
28
+ sum += static_cast<size_t>(d[dim]);
29
30
+ dims[dim] = static_cast<int>(sum);
31
32
return {{dims}};
33
}
0 commit comments