From 28a55a83d4f8839a7b631243e5f82dc648d223f3 Mon Sep 17 00:00:00 2001
From: zhangyue
Date: Tue, 12 May 2026 16:14:45 +0800
Subject: [PATCH] feat(ascend): add embedding operator

---
 src/base/embedding.h                     | 69 +++++++++++++++++++++
 src/native/ascend/ops/embedding/kernel.h | 84 ++++++++++++++++++++++++
 tests/test_embedding.py                  | 69 +++++++++++++++++++++
 3 files changed, 222 insertions(+)
 create mode 100644 src/base/embedding.h
 create mode 100644 src/native/ascend/ops/embedding/kernel.h
 create mode 100644 tests/test_embedding.py

diff --git a/src/base/embedding.h b/src/base/embedding.h
new file mode 100644
index 000000000..cb1caac1e
--- /dev/null
+++ b/src/base/embedding.h
@@ -0,0 +1,69 @@
+#ifndef INFINI_OPS_BASE_EMBEDDING_H_
+#define INFINI_OPS_BASE_EMBEDDING_H_
+
+#include <cassert>
+#include <cstddef>
+
+#include "data_type.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// Embedding performs a token embedding lookup.
+//
+// Interface follows the inference-time vLLM/PyTorch convention:
+// `out = weight[input_ids]`.
+//
+// The input layout is:
+// `input_ids`: Any shape, `int32` or `int64`.
+// `weight`: `[vocab_size, hidden_size]`.
+// `out`: `input_ids.shape + [hidden_size]`.
+//
+// This is the inference subset of `torch.nn.functional.embedding`; options
+// such as `padding_idx`, `max_norm`, `scale_grad_by_freq`, and `sparse` are
+// intentionally not part of this operator.
+class Embedding : public Operator {
+ public:
+  Embedding(const Tensor input_ids, const Tensor weight, Tensor out)
+      : num_tokens_{input_ids.numel()},
+        vocab_size_{weight.size(0)},
+        hidden_size_{weight.size(1)},
+        input_dtype_{input_ids.dtype()},
+        weight_dtype_{weight.dtype()} {
+    assert((input_dtype_ == DataType::kInt32 ||
+            input_dtype_ == DataType::kInt64) &&
+           "`Embedding` requires `input_ids` to be `int32` or `int64`");
+    assert(
+        weight.ndim() == 2 &&
+        "`Embedding` requires `weight` to be 2D `[vocab_size, hidden_size]`");
+    assert(out.dtype() == weight.dtype() &&
+           "`Embedding` requires `out` and `weight` to have the same dtype");
+    assert(out.ndim() == input_ids.ndim() + 1 &&
+           "`Embedding` requires `out.ndim == input_ids.ndim + 1`");
+    assert(out.size(-1) == hidden_size_ &&
+           "`Embedding` requires `out.shape[-1] == weight.shape[-1]`");
+
+    for (std::size_t i = 0; i < input_ids.ndim(); ++i) {
+      assert(out.size(i) == input_ids.size(i) &&
+             "`Embedding` requires `out` prefix shape to match `input_ids`");
+    }
+  }
+
+  virtual void operator()(const Tensor input_ids, const Tensor weight,
+                          Tensor out) const = 0;
+
+ protected:
+  Tensor::Size num_tokens_{0};
+
+  Tensor::Size vocab_size_{0};
+
+  Tensor::Size hidden_size_{0};
+
+  const DataType input_dtype_;
+
+  const DataType weight_dtype_;
+};
+
+}  // namespace infini::ops
+
+#endif  // INFINI_OPS_BASE_EMBEDDING_H_
diff --git a/src/native/ascend/ops/embedding/kernel.h b/src/native/ascend/ops/embedding/kernel.h
new file mode 100644
index 000000000..4b7b0524a
--- /dev/null
+++ b/src/native/ascend/ops/embedding/kernel.h
@@ -0,0 +1,84 @@
+#ifndef INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
+#define INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
+
+#include <cassert>
+#include <cstdint>
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_embedding.h"
+#include "base/embedding.h"
+#include "native/ascend/common.h"
+#include "native/ascend/workspace_pool_.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+template <>
+class Operator<Device::kAscend, Embedding> : public Embedding {
+ public:
+  Operator(const Tensor input_ids, const Tensor weight, Tensor out)
+      : Embedding(input_ids, weight, out),
+        input_ids_cache_(input_ids),
+        weight_cache_(weight),
+        out_cache_(out) {
+    assert((weight_dtype_ == DataType::kFloat16 ||
+            weight_dtype_ == DataType::kBFloat16 ||
+            weight_dtype_ == DataType::kFloat32) &&
+           "`Embedding`: Ascend path supports `float16`, `bfloat16`, and "
+           "`float32` weights");
+  }
+
+  ~Operator() {
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    // An executor marked repeatable is not freed by the launch; destroy it.
+    if (executor_) aclDestroyAclOpExecutor(executor_);
+
+    input_ids_cache_.release();
+    weight_cache_.release();
+    out_cache_.release();
+  }
+
+  void operator()(const Tensor input_ids, const Tensor weight,
+                  Tensor out) const override {
+    auto stream = static_cast<aclrtStream>(stream_);
+
+    auto t_weight = weight_cache_.get(const_cast<void*>(weight.data()));
+    auto t_input_ids =
+        input_ids_cache_.get(const_cast<void*>(input_ids.data()));
+    auto t_out = out_cache_.get(out.data());
+
+    if (!executor_) {
+      auto ret = aclnnEmbeddingGetWorkspaceSize(t_weight, t_input_ids, t_out,
+                                                &ws_size_, &executor_);
+      assert(ret == ACL_SUCCESS && "`aclnnEmbeddingGetWorkspaceSize` failed");
+      aclSetAclOpExecutorRepeatable(executor_);
+    } else {
+      aclSetInputTensorAddr(executor_, 0, t_weight,
+                            const_cast<void*>(weight.data()));
+      aclSetInputTensorAddr(executor_, 1, t_input_ids,
+                            const_cast<void*>(input_ids.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
+    }
+
+    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
+    auto ret = aclnnEmbedding(arena.buf, ws_size_, executor_, stream);
+    assert(ret == ACL_SUCCESS && "`aclnnEmbedding` failed");
+  }
+
+ private:
+  mutable ascend::AclTensorCache input_ids_cache_;
+
+  mutable ascend::AclTensorCache weight_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif  // INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
diff --git a/tests/test_embedding.py b/tests/test_embedding.py
new file mode 100644
index 000000000..2403b0745
--- /dev/null
+++ b/tests/test_embedding.py
@@ -0,0 +1,69 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, get_stream
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "input_shape, vocab_size, hidden_size",
+    (
+        ((5,), 17, 8),
+        ((2, 3), 23, 16),
+    ),
+)
+@pytest.mark.parametrize("index_dtype", (torch.int32, torch.int64))
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 0.0, 0.0),
+        (torch.float16, 0.0, 0.0),
+        (torch.bfloat16, 0.0, 0.0),
+    ),
+)
+def test_embedding(
+    input_shape,
+    vocab_size,
+    hidden_size,
+    index_dtype,
+    implementation_index,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    input_ids = torch.randint(
+        0, vocab_size, input_shape, dtype=index_dtype, device=device
+    )
+    weight = torch.randn((vocab_size, hidden_size), dtype=dtype, device=device)
+    out = torch.empty((*input_shape, hidden_size), dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args, **kwargs: _embedding(
+            *args, **kwargs, implementation_index=implementation_index
+        ),
+        _ref_embedding,
+        (input_ids, weight, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _embedding(input_ids, weight, out, *, implementation_index=0):
+    infini.ops.embedding(
+        input_ids,
+        weight,
+        out,
+        implementation_index=implementation_index,
+        stream=get_stream(input_ids.device),
+    )
+
+    return out
+
+
+def _ref_embedding(input_ids, weight, out):
+    del out
+
+    return torch.nn.functional.embedding(input_ids.long(), weight)
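
A minimal usage sketch for reviewers, mirroring tests/test_embedding.py: the
`infini.ops.embedding(input_ids, weight, out, implementation_index=..., stream=...)`
call and the `get_stream` helper are exactly what the test exercises. The
shapes are illustrative only, and the `device="npu"` placement assumes a
torch_npu build; on other builds, substitute the device the test suite's
`device` fixture would provide.

    import infini.ops
    import torch

    from tests.utils import get_stream

    # 17-token vocabulary with 8-dim embeddings, looked up for 5 token ids.
    # NOTE: device="npu" assumes torch_npu; adjust to your build.
    input_ids = torch.randint(0, 17, (5,), dtype=torch.int32, device="npu")
    weight = torch.randn((17, 8), dtype=torch.float16, device="npu")
    # The operator writes into a preallocated output shaped
    # `input_ids.shape + [hidden_size]`.
    out = torch.empty((5, 8), dtype=torch.float16, device="npu")

    infini.ops.embedding(
        input_ids,
        weight,
        out,
        implementation_index=0,
        stream=get_stream(input_ids.device),
    )

    # Semantics documented in src/base/embedding.h: out == weight[input_ids].
    torch.testing.assert_close(
        out, torch.nn.functional.embedding(input_ids.long(), weight)
    )

Requiring a preallocated `out` keeps the hot path allocation-free, which is
why the reference in the test ignores its `out` argument while the operator
fills it in place.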