|
1 | | -#include "../../utils.hpp" |
2 | | -#include "infinicore/common/hash.hpp" |
3 | | -#include "infinicore/ops/common/cache.hpp" |
| 1 | +#include "../infiniop_impl.hpp" |
4 | 2 | #include "infinicore/ops/embedding.hpp" |
5 | | -#include <infiniop.h> |
6 | 3 |
|
7 | 4 | namespace infinicore::op::embedding_impl::infiniop { |
8 | 5 |
|
// Cache of infiniop Embedding descriptors, declared via the shared macro
// helper from infiniop_impl.hpp. The third argument (100) is presumably the
// cache capacity — it matches the capacity of the hand-rolled thread_local
// common::OpCache this replaced, which also destroyed evicted descriptors
// via infiniopDestroyEmbeddingDescriptor. TODO(review): confirm the macro's
// eviction/destruction semantics in infiniop_impl.hpp.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Embedding, 100);
17 | 7 |
|
// Everything run() needs, captured once at plan() time: the cached infiniop
// descriptor plus graph-time handles to the three operand tensors.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor; // shared ownership keeps the cached descriptor alive
    graph::GraphTensor out, input, weight;  // operand handles resolved to data() at run()
};
| 12 | + |
| 13 | +void *plan(Tensor out, Tensor input, Tensor weight) { |
19 | 14 | size_t seed = hash_combine(out, input, weight); |
20 | 15 |
|
21 | | - auto device = context::getDevice(); |
22 | | - auto &cache = caches.getCache(device); |
| 16 | + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( |
| 17 | + Descriptor, descriptor, Embedding, |
| 18 | + seed, out->desc(), input->desc(), weight->desc()); |
| 19 | + |
| 20 | + auto planned = new PlannedMeta{ |
| 21 | + descriptor, |
| 22 | + graph::GraphTensor(out), |
| 23 | + graph::GraphTensor(input), |
| 24 | + graph::GraphTensor(weight)}; |
23 | 25 |
|
24 | | - auto desc_opt = cache.get(seed); |
25 | | - infiniopEmbeddingDescriptor_t desc = nullptr; |
| 26 | + return planned; |
| 27 | +} |
26 | 28 |
|
27 | | - if (!desc_opt) { |
28 | | - INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor( |
29 | | - context::getInfiniopHandle(device), &desc, |
30 | | - out->desc(), input->desc(), weight->desc())); |
31 | | - cache.put(seed, desc); |
32 | | - } else { |
33 | | - desc = *desc_opt; |
34 | | - } |
| 29 | +void run(void *planned_meta) { |
| 30 | + auto planned = reinterpret_cast<PlannedMeta *>(planned_meta); |
35 | 31 |
|
36 | 32 | INFINICORE_CHECK_ERROR(infiniopEmbedding( |
37 | | - desc, |
38 | | - out->data(), |
39 | | - input->data(), |
40 | | - weight->data(), |
41 | | - context::getStream())); |
| 33 | + planned->descriptor->desc, |
| 34 | + planned->out->data(), planned->input->data(), planned->weight->data(), context::getStream())); |
| 35 | +} |
| 36 | + |
| 37 | +void cleanup(void **planned_meta_ptr) { |
| 38 | + delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr); |
| 39 | + *planned_meta_ptr = nullptr; |
42 | 40 | } |
43 | 41 |
|
44 | | -static bool registered = []() { |
45 | | - Embedding::dispatcher().registerAll(&calculate, false); |
46 | | - return true; |
47 | | -}(); |
| 42 | +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Embedding, &plan, &run, cleanup); |
48 | 43 |
|
49 | 44 | } // namespace infinicore::op::embedding_impl::infiniop |
0 commit comments