forked from LearningInfiniTensor/TinyInfiniTensor
-
Notifications
You must be signed in to change notification settings - Fork 98
Expand file tree
/
Copy pathmatmul.cc
More file actions
66 lines (58 loc) · 2.29 KB
/
matmul.cc
File metadata and controls
66 lines (58 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "operators/matmul.h"
#include "utils/operator_utils.h"
namespace infini
{
// Constructs a MatMul operator node computing C = op(A) * op(B), where
// op(X) transposes the last two dims of X when the matching flag is set.
//   graph  - owning graph; passed to checkValid, which validates the node
//            (and, per OperatorObj convention, runs shape inference).
//   A, B   - input tensors (registered as inputs[0] / inputs[1]).
//   C      - output tensor (registered as outputs[0]).
//   transA - treat A's trailing two dims as transposed.
//   transB - treat B's trailing two dims as transposed.
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                     bool transB)
    : OperatorObj(OpType::MatMul, TensorVec{A, B}, {C}),
      transA(transA), transB(transB)
{
    // Fail fast if the node is malformed (e.g. incompatible shapes).
    IT_ASSERT(checkValid(graph));
}
// Returns a human-readable summary of this operator, e.g.
//   "Matmul([A^T,B],A=1,B=2,C=3,mnk=[16,32,64])"
// listing the transpose flags, input/output tensor GUIDs, and the
// cached matmul dimensions m, n, k.
string MatmulObj::toString() const
{
    std::ostringstream os;
    // BUGFIX: the closing ']' used to live inside the transB ternary
    // ("B]" on the false branch), so "Matmul([A,B^T" was emitted with no
    // bracket when transB was set. Emit the bracket unconditionally.
    os << "Matmul([" << (transA ? "A^T" : "A") << ","
       << (transB ? "B^T" : "B") << "]"
       << ",A=" << inputs[0]->getGuid()
       << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
       << ",mnk=[" << m << "," << n << "," << k << "])";
    return os.str();
}
// Infers the output shape of a (possibly batched) matrix multiply.
// The last two dims of each input are the matrix dims; any leading dims
// are batch dims broadcast numpy-style against each other
// (REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul).
//   inputs - exactly two tensors, A then B, each of rank >= 2.
// Returns a single output shape [broadcast(batchA, batchB)..., M, N].
// Side effect: records m, n, k so toString() reports meaningful values.
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs)
{
    auto shapeA = inputs[0]->getDims();
    auto shapeB = inputs[1]->getDims();
    const int rankA = static_cast<int>(shapeA.size());
    const int rankB = static_cast<int>(shapeB.size());
    // BUGFIX: rank was previously checked only inside the transpose
    // branches, yet shape[rank - 2] was indexed unconditionally below,
    // which is out of bounds for rank-1 inputs. ONNX-style 1-D vector
    // promotion is not implemented here, so reject such inputs up front.
    IT_ASSERT(rankA >= 2 && rankB >= 2);
    // Apply the transpose flags by swapping each input's last two dims.
    if (transA)
        std::swap(shapeA[rankA - 1], shapeA[rankA - 2]);
    if (transB)
        std::swap(shapeB[rankB - 1], shapeB[rankB - 2]);
    // After transposition: A is [..., M, K], B is [..., K, N].
    const int M = shapeA[rankA - 2];
    const int kA = shapeA[rankA - 1];
    const int kB = shapeB[rankB - 2];
    const int N = shapeB[rankB - 1];
    IT_ASSERT(kA == kB); // inner (reduction) dims must agree
    // Cache the matmul dims; toString() prints m/n/k, which were never
    // assigned before. NOTE(review): matches upstream InfiniTensor,
    // which sets these during shape inference — confirm against header.
    m = M;
    n = N;
    k = kA;
    // Broadcast the leading (batch) dims of both inputs, then append
    // the matrix dims to form the final output shape.
    Shape batchA(shapeA.begin(), shapeA.end() - 2);
    Shape batchB(shapeB.begin(), shapeB.end() - 2);
    Shape outShape = infer_broadcast(batchA, batchB);
    outShape.push_back(M);
    outShape.push_back(N);
    return vector<Shape>{outShape};
}
} // namespace infini