forked from LearningInfiniTensor/TinyInfiniTensor
-
Notifications
You must be signed in to change notification settings - Fork 98
Expand file tree
/
Copy pathgraph.h
More file actions
184 lines (157 loc) · 5.36 KB
/
graph.h
File metadata and controls
184 lines (157 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#pragma once
#include "core/allocator.h"
#include "core/operator.h"
#include "core/tensor.h"
#include "operators/transpose.h"
#include <algorithm>
#include <cstdint>
namespace infini
{
    // Forward declaration; the full definition is not needed in this header.
    class MatmulObj;

    /**
     * @brief A computation graph holding tensors and operators.
     *
     * Tensors and operators are stored in `tensors` and `ops`. Operators are
     * created through addOp()/addOpWithOutputs(), which register them via
     * addOperatorAndConnect().
     */
    class GraphObj : public Object
    {
    protected:
        Runtime runtime;     // Runtime passed to the constructor.
        TensorVec tensors;   // All tensors registered with this graph.
        OpVec ops;           // All operators registered with this graph.
        Allocator allocator; // Constructed from `runtime`; presumably used by dataMalloc() — confirm.

    public:
        explicit GraphObj(Runtime runtime)
            : runtime(runtime), allocator(runtime), sorted(false) {}

        string toString() const override;
        Runtime getRuntime() const { return runtime; }

        Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
        Tensor addTensor(const Tensor &tensor);
        TensorVec addTensor(const TensorVec &tensors);

        /**
         * @brief Remove an operator from the graph and detach it from its
         *        neighbouring tensors and operators.
         *
         * Steps:
         *  - removes `op` from the target list of each of its input tensors;
         *  - clears the source of each output tensor whose source is still `op`;
         *  - removes `op` from the successor list of each predecessor and from
         *    the predecessor list of each successor;
         *  - erases `op` from `ops` (no-op if it is not present).
         *
         * This method does not modify `op`'s own input/output/predecessor/
         * successor lists, so callers can still inspect them afterwards.
         *
         * @param op Operator to remove. It is taken by value on purpose: a
         *           caller may pass an element of `ops` itself, and a
         *           reference to it would dangle after the erase below.
         */
        void removeOperator(Operator op)
        {
            // Tensor <-> operator links.
            for (const auto &input : op->getInputs())
            {
                if (input)
                    input->removeTarget(op);
            }
            for (const auto &output : op->getOutputs())
            {
                // Only clear the source if it still refers to this operator.
                // An output that has already been re-attached to a different
                // producer (e.g. while rewiring the graph) must keep it.
                if (output && output->getSource() == op)
                    output->setSource(nullptr);
            }

            // Operator <-> operator links. Snapshot the neighbour lists
            // before mutating any neighbour.
            const auto predecessors = op->getPredecessors();
            for (const auto &pred : predecessors)
            {
                if (pred)
                    pred->removeSuccessors(op);
            }
            const auto successors = op->getSuccessors();
            for (const auto &succ : successors)
            {
                if (succ)
                    succ->removePredecessors(op);
            }

            // Finally drop the operator from the graph's operator list.
            auto it = std::find(ops.begin(), ops.end(), op);
            if (it != ops.end())
                ops.erase(it);
        }

        /**
         * @brief Erase the first occurrence of `tensor` from `tensors`.
         *
         * Only the graph's tensor list is updated; the tensor's own
         * source/target links are not touched here. Does nothing if the
         * tensor is not registered.
         */
        void removeTensor(Tensor tensor)
        {
            auto it = std::find(tensors.begin(), tensors.end(), tensor);
            if (it != tensors.end())
                tensors.erase(it);
        }

        const TensorVec &getTensors() const { return tensors; }
        const OpVec &getOperators() const { return ops; }

        // Presumably looks a tensor up by its id/FUID — confirm against the
        // out-of-line definition.
        Tensor getTensor(int) const;

        /**
         * @brief Sort the operators in topological order.
         * Returns true on success. Returns false when the graph contains a
         * cycle, in which case no topological order exists.
         */
        bool topo_sort();
        void optimize();
        void shape_infer();
        void dataMalloc();

        /**
         * @brief Add an operator and create its outputs.
         *
         * Constructs T with this graph as its first constructor argument,
         * followed by `args`, then registers it via addOperatorAndConnect().
         * Output tensor arguments should be empty Refs (e.g. nullptr).
         */
        template <typename T, typename... Args>
        Ref<T> addOp(Args &&...args)
        {
            Ref<T> op = infini::make_ref<T>(this, std::forward<Args>(args)...);
            addOperatorAndConnect(op);
            return op;
        }

        /**
         * @brief Add an operator whose output tensors are already specified.
         *
         * Constructs T with nullptr in place of the graph argument, followed
         * by `args`, then registers it via addOperatorAndConnect().
         */
        template <typename T, typename... Args>
        Ref<T> addOpWithOutputs(Args &&...args)
        {
            Ref<T> op = infini::make_ref<T>(nullptr, std::forward<Args>(args)...);
            addOperatorAndConnect(op);
            return op;
        }

        /**
         * @brief Get the input tensors of this graph: every registered
         *        tensor that has no source operator.
         */
        inline TensorVec getInputs() const
        {
            TensorVec ret;
            for (const auto &t : tensors)
                if (!t->getSource())
                    ret.emplace_back(t);
            return ret;
        }

        /**
         * @brief Get the output tensors of this graph: every registered
         *        tensor that has no target operators.
         */
        inline TensorVec getOutputs() const
        {
            TensorVec ret;
            for (const auto &t : tensors)
                if (t->getTargets().empty())
                    ret.emplace_back(t);
            return ret;
        }

        // Presumably verifies the consistency of the graph's tensor/operator
        // links — see the out-of-line definition.
        bool checkValid() const;

    private:
        /**
         * @brief Register `op` with the graph and build the reverse
         *        connections (tensor source/targets, operator
         *        predecessors/successors). Called by addOp()/addOpWithOutputs().
         */
        void addOperatorAndConnect(const Operator &op);

        /**
         * @brief Whether `ops` is currently in topological order.
         */
        bool sorted;

        /**
         * @brief Check whether two transposes are inverses of each other
         *        (presumably: applying both yields the identity permutation).
         */
        bool areInverseTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2);

        /**
         * @brief Check whether two transposes perform the same permutation.
         */
        bool areSameTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2);

        /**
         * @brief Check whether a transpose swaps only the last two dimensions.
         */
        bool isLastTwoDimsSwap(const TransposeObj *transpose);

        /**
         * @brief Fold a transpose operator into a following matmul operator.
         */
        void mergeTransposeToMatmul(const Operator &transpose, const Operator &matmul);

        /**
         * @brief Reconnect the graph after removing a pair of operators.
         */
        void reconnectGraph(const Operator &op1, const Operator &op2);

        /**
         * @brief Remove tensors that are no longer used by any operator.
         */
        void cleanupUnusedTensors();
    };
} // namespace infini