-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathgraph.hpp
More file actions
97 lines (80 loc) · 4.04 KB
/
graph.hpp
File metadata and controls
97 lines (80 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#pragma once
#include <memory>
#include <utility>
#include <vector>

#include "../tensor.hpp"
namespace infinicore::graph {

// Forward declarations
class GraphManager;

// Tensor type used by graph operators, built from an existing Tensor through
// the out-of-line constructor declared below.
// NOTE(review): the constructor is not `explicit`, so any Tensor converts
// implicitly to a GraphTensor; kept as-is so existing callers keep compiling.
class GraphTensor : public Tensor {
public:
    GraphTensor(const Tensor &);
};

// Abstract interface for a single operator in a graph: something that can be
// run. It is a polymorphic base (Graph stores operators as
// std::shared_ptr<GraphOperator>), hence the virtual destructor.
class GraphOperator {
public:
    virtual void run() const = 0;
    virtual ~GraphOperator() = default;
};

// GraphOperator whose behavior comes from type-erased function pointers that
// are looked up per device (see INFINICORE_GRAPH_OP_DISPATCH):
//   planned_meta_ - opaque state produced by the operator's plan function,
//   runner_       - function that executes the operator on planned_meta_,
//   deleter_      - function that releases planned_meta_ (takes its address).
// run() and the destructor are defined out of line; presumably they call
// runner_(planned_meta_) and deleter_(&planned_meta_) - confirm in the .cpp.
class DispatchableGraphOperator : public GraphOperator {
public:
    DispatchableGraphOperator() = default;

    // Non-copyable: a copy would alias planned_meta_, so two destructors
    // would each release the same planned state.
    DispatchableGraphOperator(const DispatchableGraphOperator &) = delete;
    DispatchableGraphOperator &operator=(const DispatchableGraphOperator &) = delete;

    void run() const override;
    ~DispatchableGraphOperator() override;

protected:
    using run_schema = void (*)(void *);
    using cleanup_schema = void (*)(void **);

    // Null until assigned (e.g. by INFINICORE_GRAPH_OP_DISPATCH in a derived
    // constructor), so a destructor running after a failed or partial
    // dispatch never reads indeterminate values.
    void *planned_meta_ = nullptr;
    run_schema runner_ = nullptr;
    cleanup_schema deleter_ = nullptr;
};

// Ordered collection of graph operators. Only GraphManager (a friend) and
// derived classes can append through add_operator(). run() is defined out of
// line; presumably it runs op_list_ in insertion order (confirm).
// There is no user-declared destructor (Rule of Zero), so Graph gets implicit
// move operations in addition to copies.
class Graph {
public:
    Graph() = default;

    void run() const;

protected:
    void add_operator(std::shared_ptr<GraphOperator> op);

    std::vector<std::shared_ptr<GraphOperator>> op_list_;

    friend class GraphManager;
};

} // namespace infinicore::graph
// Declares operator class __OP_NAME__, a graph::DispatchableGraphOperator whose
// parameter list is given by the variadic arguments. That list is reused for:
//   - the constructor and the static execute() function,
//   - `schema`      (pointer to a void function taking those parameters),
//   - `plan_schema` (same parameters, returning opaque planned state as void*).
// The three dispatcher accessors are defined by INFINICORE_GRAPH_OP_DISPATCHERS_IMPL.
// `graph::` and `common::OpDispatcher` must be visible where this is expanded.
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...)                              \
    class __OP_NAME__ : public graph::DispatchableGraphOperator {                \
    public:                                                                      \
        /* Function-pointer types built from the operator's parameter list. */  \
        using schema = void (*)(__VA_ARGS__);                                    \
        using plan_schema = void *(*)(__VA_ARGS__);                              \
                                                                                 \
        /* The constructor and execute() take the same parameters. */           \
        __OP_NAME__(__VA_ARGS__);                                                \
        static void execute(__VA_ARGS__);                                        \
                                                                                 \
        /* Lookup tables for the plan / run / cleanup stages. */                \
        static common::OpDispatcher<plan_schema> &plan_dispatcher();             \
        static common::OpDispatcher<run_schema> &run_dispatcher();               \
        static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher();       \
    };
// Defines the three dispatcher accessors declared by INFINICORE_GRAPH_OP_CLASS
// for __OP_NAME__. Each accessor owns a function-local static OpDispatcher,
// which is constructed on first call (thread-safe since C++11). That keeps it
// usable from other static initializers, such as the one produced by
// INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE.
// Expand exactly once per operator, in a single .cpp: the functions are not inline.
#define INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(__OP_NAME__)                             \
    auto __OP_NAME__::plan_dispatcher() -> common::OpDispatcher<plan_schema> & {       \
        static common::OpDispatcher<plan_schema> instance;                             \
        return instance;                                                               \
    }                                                                                  \
    auto __OP_NAME__::run_dispatcher() -> common::OpDispatcher<run_schema> & {         \
        static common::OpDispatcher<run_schema> instance;                              \
        return instance;                                                               \
    }                                                                                  \
    auto __OP_NAME__::cleanup_dispatcher() -> common::OpDispatcher<cleanup_schema> & { \
        static common::OpDispatcher<cleanup_schema> instance;                          \
        return instance;                                                               \
    }
// Fills in the DispatchableGraphOperator members (deleter_, runner_,
// planned_meta_) for __DEVICE_TYPE__ from the calling class's
// cleanup/run/plan dispatchers. The remaining arguments are forwarded to the
// looked-up plan function. Use it inside a member (typically the constructor)
// of a class declared with INFINICORE_GRAPH_OP_CLASS.
// The cleanup and run functions are resolved *before* planning. If one of
// those lookups throws, nothing has been planned yet, so no state is left
// allocated without a deleter to release it.
// The braces make the expansion a single statement; a trailing `;` at the
// call site is still accepted.
#define INFINICORE_GRAPH_OP_DISPATCH(__DEVICE_TYPE__, ...)                      \
    {                                                                           \
        deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);                \
        runner_ = run_dispatcher().lookup(__DEVICE_TYPE__);                     \
        planned_meta_ = plan_dispatcher().lookup(__DEVICE_TYPE__)(__VA_ARGS__); \
    }
// Creates a __OP_NAME__ (via std::make_shared) from the remaining arguments.
// If context::isGraphRecording() is true, the op is handed to
// context::addGraphOperator; otherwise it is run immediately.
// The statements live in their own scope, so the macro can be expanded more
// than once in the same function and does not leak a (formerly reserved,
// double-underscore) identifier into the caller's scope. As a consequence, in
// the non-recording path the op is released as soon as run() returns rather
// than at the end of the caller's scope.
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...)                         \
    {                                                                               \
        auto infinicore_graph_op_ = std::make_shared<__OP_NAME__>(__VA_ARGS__);     \
        if (context::isGraphRecording()) {                                          \
            /* Transfer ownership; avoids an extra reference-count round trip. */   \
            context::addGraphOperator(std::move(infinicore_graph_op_));             \
        } else {                                                                    \
            infinicore_graph_op_->run();                                            \
        }                                                                           \
    }
// Token-pasting helpers: the extra level of indirection makes arguments such
// as __LINE__ expand to their value before they are concatenated.
#define INFINICORE_GRAPH_DETAIL_CONCAT_IMPL(a, b) a##b
#define INFINICORE_GRAPH_DETAIL_CONCAT(a, b) INFINICORE_GRAPH_DETAIL_CONCAT_IMPL(a, b)

// Registers __PLAN_F__, __RUN_F__ and __CLEANUP_F__ through registerAll on
// __OP_NAME__'s plan/run/cleanup dispatchers. This happens when the generated
// static bool is initialized; at namespace scope, that is during the
// translation unit's dynamic initialization. The second argument `false` is
// passed through unchanged.
// NOTE(review): presumably it means "do not override existing entries";
// confirm in OpDispatcher.
// The guard variable's name is made unique per source line, so several
// operators can be registered in one .cpp. It is marked [[maybe_unused]]
// because it exists only for its initializer's side effects.
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__)       \
    [[maybe_unused]] static bool INFINICORE_GRAPH_DETAIL_CONCAT(infinicore_graph_op_registered_, __LINE__) \
        = []() {                                                                                        \
              __OP_NAME__::plan_dispatcher().registerAll(__PLAN_F__, false);                            \
              __OP_NAME__::run_dispatcher().registerAll(__RUN_F__, false);                              \
              __OP_NAME__::cleanup_dispatcher().registerAll(__CLEANUP_F__, false);                      \
              return true;                                                                              \
          }();