Skip to content

Commit 736a9fd

Browse files
issue/810: support more ops as graph ops
1 parent 7c97894 commit 736a9fd

38 files changed

Lines changed: 1073 additions & 680 deletions

include/infinicore/graph/graph.hpp

Lines changed: 14 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -15,10 +15,15 @@ class GraphTensor : public Tensor {
1515
};
1616

1717
class GraphOperator {
18+
public:
19+
virtual void run() const = 0;
20+
virtual ~GraphOperator() = default;
21+
};
1822

23+
class DispatchableGraphOperator : public GraphOperator {
1924
public:
20-
void run() const;
21-
~GraphOperator();
25+
void run() const override;
26+
~DispatchableGraphOperator() override;
2227

2328
protected:
2429
using run_schema = void (*)(void *);
@@ -45,7 +50,7 @@ class Graph {
4550
} // namespace infinicore::graph
4651

4752
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
48-
class __OP_NAME__ : public graph::GraphOperator { \
53+
class __OP_NAME__ : public graph::DispatchableGraphOperator { \
4954
public: \
5055
using schema = void (*)(__VA_ARGS__); \
5156
using plan_schema = void *(*)(__VA_ARGS__); \
@@ -75,12 +80,12 @@ class Graph {
7580
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
7681
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
7782

78-
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
79-
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
80-
if (context::isGraphRecording()) { \
81-
context::addGraphOperator(op); \
82-
} else { \
83-
op->run(); \
83+
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
84+
auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
85+
if (context::isGraphRecording()) { \
86+
context::addGraphOperator(___op); \
87+
} else { \
88+
___op->run(); \
8489
}
8590

8691
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \

include/infinicore/ops/add.hpp

Lines changed: 6 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -1,17 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
7-
class Add {
8-
public:
9-
using schema = void (*)(Tensor, Tensor, Tensor);
10-
static void execute(Tensor c, Tensor a, Tensor b);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
138

14-
Tensor add(Tensor a, Tensor b);
15-
void add_(Tensor c, Tensor a, Tensor b);
16-
Tensor operator+(Tensor a, Tensor b);
9+
INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &);
10+
11+
Tensor add(const Tensor &a, const Tensor &b);
12+
void add_(Tensor c, const Tensor &a, const Tensor &b);
13+
1714
} // namespace infinicore::op
Lines changed: 6 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -1,16 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
7-
class CausalSoftmax {
8-
public:
9-
using schema = void (*)(Tensor, Tensor);
10-
static void execute(Tensor output, Tensor input);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
138

14-
Tensor causal_softmax(Tensor input);
15-
void causal_softmax_(Tensor output, Tensor input);
9+
INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &);
10+
11+
Tensor causal_softmax(const Tensor &input);
12+
void causal_softmax_(Tensor output, const Tensor &input);
13+
1614
} // namespace infinicore::op
Lines changed: 24 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include "../../device.hpp"
4+
#include "../../graph/graph.hpp"
5+
#include "../common/op.hpp"
6+
7+
#include <infiniccl.h>
8+
9+
namespace infinicore::op::distributed {
10+
class AllReduce : public graph::GraphOperator {
11+
public:
12+
AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
13+
~AllReduce();
14+
void run() const override;
15+
static void execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
16+
17+
private:
18+
void *planned_meta_;
19+
};
20+
21+
Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
22+
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
23+
24+
} // namespace infinicore::op::distributed

include/infinicore/ops/gemm.hpp

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -6,9 +6,9 @@
66

77
namespace infinicore::op {
88

9-
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
9+
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float);
1010

11-
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
12-
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
11+
Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f);
12+
void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta);
1313

1414
} // namespace infinicore::op

include/infinicore/ops/mul.hpp

Lines changed: 6 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -1,16 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
7-
class Mul {
8-
public:
9-
using schema = void (*)(Tensor, Tensor, Tensor);
10-
static void execute(Tensor c, Tensor a, Tensor b);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
138

14-
Tensor mul(Tensor a, Tensor b);
15-
void mul_(Tensor c, Tensor a, Tensor b);
9+
INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &);
10+
11+
Tensor mul(const Tensor &a, const Tensor &b);
12+
void mul_(Tensor c, const Tensor &a, const Tensor &b);
13+
1614
} // namespace infinicore::op
Lines changed: 10 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -1,18 +1,20 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56
#include <optional>
67

78
namespace infinicore::op {
89

9-
class PagedAttention {
10-
public:
11-
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
12-
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
13-
static common::OpDispatcher<schema> &dispatcher();
14-
};
10+
INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>, float);
11+
12+
Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
13+
const Tensor &block_tables, const Tensor &kv_lens,
14+
std::optional<Tensor> alibi_slopes, float scale);
15+
16+
void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
17+
const Tensor &block_tables, const Tensor &kv_lens,
18+
std::optional<Tensor> alibi_slopes, float scale);
1519

16-
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
17-
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
1820
} // namespace infinicore::op
Lines changed: 3 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -1,17 +1,13 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class PagedCaching {
9-
public:
10-
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
11-
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);
1410

15-
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
11+
void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping);
1612

1713
} // namespace infinicore::op
Lines changed: 6 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -1,16 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
7-
class Rearrange {
8-
public:
9-
using schema = void (*)(Tensor, Tensor);
10-
static void execute(Tensor y, Tensor x);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
138

14-
Tensor rearrange(Tensor x);
15-
void rearrange_(Tensor y, Tensor x);
9+
INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &);
10+
11+
Tensor rearrange(const Tensor &x);
12+
void rearrange_(Tensor y, const Tensor &x);
13+
1614
} // namespace infinicore::op
Lines changed: 6 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -1,16 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
7-
class RMSNorm {
8-
public:
9-
using schema = void (*)(Tensor, Tensor, Tensor, float);
10-
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
138

14-
Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
15-
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
9+
INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float);
10+
11+
Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
12+
void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
13+
1614
} // namespace infinicore::op

0 commit comments

Comments (0)