Skip to content

Commit 4866ddf

Browse files
committed
issue/900 - adapt to graph and adjust test script
1 parent 5fd934c commit 4866ddf

7 files changed

Lines changed: 328 additions & 813 deletions

File tree

include/infinicore/ops/embedding.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
 #pragma once
 
+#include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
 
-class Embedding {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
-    static void execute(Tensor out, Tensor input, Tensor weight);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(Embedding, Tensor, Tensor, Tensor);
 
 Tensor embedding(Tensor input, Tensor weight);
 void embedding_(Tensor out, Tensor input, Tensor weight);

src/infinicore/ops/embedding/embedding.cc

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,15 @@
 #include <stdexcept>
 
 namespace infinicore::op {
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);
 
-common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
-    static common::OpDispatcher<Embedding::schema> dispatcher_;
-    return dispatcher_;
+Embedding::Embedding(Tensor out, Tensor input, Tensor weight) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
+    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, input, weight);
 }
 
 void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
-    // Check that all tensors are on the same device
-    // This is critical: if input is on CPU while out/weight are on GPU,
-    // passing CPU pointer to CUDA kernel will cause memory access errors
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
-
-    // Set device context
-    infinicore::context::setDevice(out->device());
-
-    // Use dispatcher to lookup kernel (infiniop implementation)
-    dispatcher().lookup(out->device().getType())(out, input, weight);
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Embedding, out, input, weight);
 }
 
 Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract
Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,44 @@
-#include "../../utils.hpp"
-#include "infinicore/common/hash.hpp"
-#include "infinicore/ops/common/cache.hpp"
+#include "../infiniop_impl.hpp"
 #include "infinicore/ops/embedding.hpp"
-#include <infiniop.h>
 
 namespace infinicore::op::embedding_impl::infiniop {
 
-thread_local common::OpCache<size_t, infiniopEmbeddingDescriptor_t> caches(
-    100, // capacity
-    [](infiniopEmbeddingDescriptor_t &desc) {
-        if (desc != nullptr) {
-            INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc));
-            desc = nullptr;
-        }
-    });
+INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Embedding, 100);
 
-void calculate(Tensor out, Tensor input, Tensor weight) {
+struct PlannedMeta {
+    std::shared_ptr<Descriptor> descriptor;
+    graph::GraphTensor out, input, weight;
+};
+
+void *plan(Tensor out, Tensor input, Tensor weight) {
     size_t seed = hash_combine(out, input, weight);
 
-    auto device = context::getDevice();
-    auto &cache = caches.getCache(device);
+    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
+        Descriptor, descriptor, Embedding,
+        seed, out->desc(), input->desc(), weight->desc());
+
+    auto planned = new PlannedMeta{
+        descriptor,
+        graph::GraphTensor(out),
+        graph::GraphTensor(input),
+        graph::GraphTensor(weight)};
 
-    auto desc_opt = cache.get(seed);
-    infiniopEmbeddingDescriptor_t desc = nullptr;
+    return planned;
+}
 
-    if (!desc_opt) {
-        INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor(
-            context::getInfiniopHandle(device), &desc,
-            out->desc(), input->desc(), weight->desc()));
-        cache.put(seed, desc);
-    } else {
-        desc = *desc_opt;
-    }
+void run(void *planned_meta) {
+    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
 
     INFINICORE_CHECK_ERROR(infiniopEmbedding(
-        desc,
-        out->data(),
-        input->data(),
-        weight->data(),
-        context::getStream()));
+        planned->descriptor->desc,
+        planned->out->data(), planned->input->data(), planned->weight->data(), context::getStream()));
+}
+
+void cleanup(void **planned_meta_ptr) {
+    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
+    *planned_meta_ptr = nullptr;
 }
 
-static bool registered = []() {
-    Embedding::dispatcher().registerAll(&calculate, false);
-    return true;
-}();
+INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Embedding, &plan, &run, cleanup);
 
 } // namespace infinicore::op::embedding_impl::infiniop

0 commit comments

Comments
 (0)