diff --git a/src/base/top_k_top_p_sampler.h b/src/base/top_k_top_p_sampler.h new file mode 100644 index 00000000..085a2b93 --- /dev/null +++ b/src/base/top_k_top_p_sampler.h @@ -0,0 +1,78 @@ +#ifndef INFINI_OPS_BASE_TOP_K_TOP_P_SAMPLER_H_ +#define INFINI_OPS_BASE_TOP_K_TOP_P_SAMPLER_H_ + +#include +#include + +#include "data_type.h" +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +// `TopKTopPSampler` samples token ids from 2D `logits` after optional rank and +// nucleus filtering. The name and tensor boundary follow vLLM's +// `TopKTopPSampler`; temperature scaling is intentionally handled by callers. +// The optional `k` and `p` tensors may be shaped as `[1]` or `[batch_size]`. +class TopKTopPSampler : public Operator { + public: + TopKTopPSampler(const Tensor logits, std::optional k, + std::optional p, Tensor out) + : batch_size_{logits.size(0)}, + vocab_size_{logits.size(1)}, + dtype_{logits.dtype()} { + assert(logits.ndim() == 2 && + "`TopKTopPSampler` requires 2D `[batch_size, vocab_size]` logits"); + assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 || + dtype_ == DataType::kFloat32 || dtype_ == DataType::kFloat64) && + "`TopKTopPSampler` requires floating-point logits"); + assert(out.ndim() == 1 && + "`TopKTopPSampler` requires 1D `[batch_size]` output"); + assert(out.size(0) == batch_size_ && + "`TopKTopPSampler` requires output batch size to match logits"); + assert(out.dtype() == DataType::kInt32 && + "`TopKTopPSampler` requires int32 output"); + + ValidateK(k); + ValidateP(p); + } + + virtual void operator()(const Tensor logits, std::optional k, + std::optional p, Tensor out) const = 0; + + protected: + void ValidateK(std::optional k) const { + if (!k.has_value()) return; + + assert(k->ndim() == 1 && + "`TopKTopPSampler` requires `k` to be 1D when provided"); + assert((k->size(0) == 1 || k->size(0) == batch_size_) && + "`TopKTopPSampler` requires `k` shape [1] or [batch_size]"); + assert((k->dtype() == 
DataType::kInt32 || k->dtype() == DataType::kInt64) && + "`TopKTopPSampler` requires int32 or int64 `k`"); + } + + void ValidateP(std::optional p) const { + if (!p.has_value()) return; + + assert(p->ndim() == 1 && + "`TopKTopPSampler` requires `p` to be 1D when provided"); + assert((p->size(0) == 1 || p->size(0) == batch_size_) && + "`TopKTopPSampler` requires `p` shape [1] or [batch_size]"); + assert((p->dtype() == DataType::kFloat16 || + p->dtype() == DataType::kBFloat16 || + p->dtype() == DataType::kFloat32 || + p->dtype() == DataType::kFloat64) && + "`TopKTopPSampler` requires floating-point `p`"); + } + + Tensor::Size batch_size_{0}; + + Tensor::Size vocab_size_{0}; + + DataType dtype_; +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_BASE_TOP_K_TOP_P_SAMPLER_H_ diff --git a/src/native/ascend/ops/top_k_top_p_sampler/kernel.h b/src/native/ascend/ops/top_k_top_p_sampler/kernel.h new file mode 100644 index 00000000..79c06c3c --- /dev/null +++ b/src/native/ascend/ops/top_k_top_p_sampler/kernel.h @@ -0,0 +1,267 @@ +#ifndef INFINI_OPS_ASCEND_TOP_K_TOP_P_SAMPLER_KERNEL_H_ +#define INFINI_OPS_ASCEND_TOP_K_TOP_P_SAMPLER_KERNEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "aclnnop/aclnn_top_k_top_p_sample.h" +#include "base/top_k_top_p_sampler.h" +#include "data_type.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +template <> +class Operator + : public TopKTopPSampler { + public: + Operator(const Tensor logits, std::optional k, + std::optional p, Tensor out) + : TopKTopPSampler(logits, k, p, out) { + assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16) && + "`TopKTopPSampler` Ascend ACLNN path requires float16 or bfloat16 " + "logits"); + assert(logits.IsContiguous() && + "`TopKTopPSampler` Ascend ACLNN path requires 
contiguous logits"); + assert(out.IsContiguous() && + "`TopKTopPSampler` Ascend ACLNN path requires contiguous output"); + ValidateHostTensor(k); + ValidateHostTensor(p); + + logits_cache_ = ascend::AclTensorCache(logits); + top_k_cache_ = ascend::AclTensorCache({static_cast(batch_size_)}, + ACL_INT32, nullptr); + top_p_cache_ = ascend::AclTensorCache({static_cast(batch_size_)}, + ascend::ToAclDtype(dtype_), nullptr); + selected_idx_cache_ = ascend::AclTensorCache( + {static_cast(batch_size_)}, ACL_INT64, nullptr); + selected_logits_cache_ = ascend::AclTensorCache( + {static_cast(batch_size_), static_cast(vocab_size_)}, + ACL_FLOAT, nullptr); + out_cache_ = ascend::AclTensorCache(out); + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + logits_cache_.release(); + top_k_cache_.release(); + top_p_cache_.release(); + selected_idx_cache_.release(); + selected_logits_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor logits, std::optional k, + std::optional p, Tensor out) const override { + assert(logits.IsContiguous() && + "`TopKTopPSampler` Ascend ACLNN path requires contiguous logits"); + assert(out.IsContiguous() && + "`TopKTopPSampler` Ascend ACLNN path requires contiguous output"); + assert(IsGreedy(k) && + "`TopKTopPSampler` Ascend ACLNN path supports `top_k == 1` only"); + + auto stream = static_cast(stream_); + auto top_k_bytes = batch_size_ * kDataTypeToSize.at(DataType::kInt32); + auto top_p_bytes = batch_size_ * kDataTypeToSize.at(dtype_); + auto selected_idx_bytes = + batch_size_ * kDataTypeToSize.at(DataType::kInt64); + auto selected_logits_bytes = + batch_size_ * vocab_size_ * kDataTypeToSize.at(DataType::kFloat32); + + FillGreedyParams(p); + + auto& top_k_arena = ascend::GetWorkspacePool().Ensure( + stream, top_k_bytes, "top_k_top_p_sample_top_k"); + auto& top_p_arena = ascend::GetWorkspacePool().Ensure( + stream, top_p_bytes, "top_k_top_p_sample_top_p"); + auto ret = aclrtMemcpy(top_k_arena.buf, top_k_bytes, 
top_k_host_.data(), + top_k_bytes, ACL_MEMCPY_HOST_TO_DEVICE); + assert(ret == ACL_SUCCESS && + "`TopKTopPSampler`: copying `top_k` to Ascend failed"); + ret = aclrtMemcpy(top_p_arena.buf, top_p_bytes, top_p_host_.data(), + top_p_bytes, ACL_MEMCPY_HOST_TO_DEVICE); + assert(ret == ACL_SUCCESS && + "`TopKTopPSampler`: copying `top_p` to Ascend failed"); + + auto& selected_idx_arena = ascend::GetWorkspacePool().Ensure( + stream, selected_idx_bytes, "top_k_top_p_sample_idx"); + auto& selected_logits_arena = ascend::GetWorkspacePool().Ensure( + stream, selected_logits_bytes, "top_k_top_p_sample_logits"); + + auto t_logits = logits_cache_.get(const_cast(logits.data())); + auto t_top_k = top_k_cache_.get(top_k_arena.buf); + auto t_top_p = top_p_cache_.get(top_p_arena.buf); + auto t_selected_idx = selected_idx_cache_.get(selected_idx_arena.buf); + auto t_selected_logits = + selected_logits_cache_.get(selected_logits_arena.buf); + + if (!sample_exec_) { + ret = aclnnTopKTopPSampleGetWorkspaceSize( + t_logits, t_top_k, t_top_p, + /*qOptional=*/nullptr, /*eps=*/1e-8, /*isNeedLogits=*/false, + /*topKGuess=*/32, t_selected_idx, t_selected_logits, &sample_ws_size_, + &sample_exec_); + assert(ret == ACL_SUCCESS && + "`aclnnTopKTopPSampleGetWorkspaceSize` failed"); + aclSetAclOpExecutorRepeatable(sample_exec_); + } else { + aclSetInputTensorAddr(sample_exec_, 0, t_logits, + const_cast(logits.data())); + aclSetInputTensorAddr(sample_exec_, 1, t_top_k, top_k_arena.buf); + aclSetInputTensorAddr(sample_exec_, 2, t_top_p, top_p_arena.buf); + aclSetOutputTensorAddr(sample_exec_, 0, t_selected_idx, + selected_idx_arena.buf); + aclSetOutputTensorAddr(sample_exec_, 1, t_selected_logits, + selected_logits_arena.buf); + } + + auto& sample_ws_arena = ascend::GetWorkspacePool().Ensure( + stream, sample_ws_size_, "top_k_top_p_sample_workspace"); + ret = aclnnTopKTopPSample(sample_ws_arena.buf, sample_ws_size_, + sample_exec_, stream); + assert(ret == ACL_SUCCESS && "`aclnnTopKTopPSample` 
failed"); + + CastSelectedIdx(selected_idx_arena.buf, out); + } + + private: + void ValidateHostTensor(std::optional tensor) const { + if (!tensor.has_value()) return; + + assert(tensor->device().type() == Device::Type::kCpu && + "`TopKTopPSampler` Ascend path currently requires host-side " + "`k`/`p` tensors"); + assert(tensor->IsContiguous() && + "`TopKTopPSampler` Ascend path requires contiguous `k`/`p` " + "tensors"); + } + + bool IsGreedy(std::optional k) const { + if (!k.has_value()) return false; + + for (Tensor::Size row = 0; row < batch_size_; ++row) { + if (GetK(k, row) != 1) return false; + } + + return true; + } + + void CastSelectedIdx(void* selected_idx, Tensor out) const { + auto stream = static_cast(stream_); + auto t_selected_idx = selected_idx_cache_.get(selected_idx); + auto t_out = out_cache_.get(out.data()); + + if (!cast_exec_) { + auto ret = aclnnCastGetWorkspaceSize(t_selected_idx, ACL_INT32, t_out, + &cast_ws_size_, &cast_exec_); + assert(ret == ACL_SUCCESS && "`aclnnCastGetWorkspaceSize` failed"); + aclSetAclOpExecutorRepeatable(cast_exec_); + } else { + aclSetInputTensorAddr(cast_exec_, 0, t_selected_idx, selected_idx); + aclSetOutputTensorAddr(cast_exec_, 0, t_out, out.data()); + } + + auto& cast_ws_arena = ascend::GetWorkspacePool().Ensure( + stream, cast_ws_size_, "top_k_top_p_sample_cast_workspace"); + auto ret = aclnnCast(cast_ws_arena.buf, cast_ws_size_, cast_exec_, stream); + assert(ret == ACL_SUCCESS && "`aclnnCast` failed"); + } + + void FillGreedyParams(std::optional p) const { + top_k_host_.assign(batch_size_, 1); + top_p_host_.resize(batch_size_ * kDataTypeToSize.at(dtype_)); + + for (Tensor::Size row = 0; row < batch_size_; ++row) { + auto value = static_cast(GetP(p, row)); + auto* dst = top_p_host_.data() + row * kDataTypeToSize.at(dtype_); + + if (dtype_ == DataType::kFloat16) { + auto converted = Float16::FromFloat(value); + std::memcpy(dst, &converted, sizeof(converted)); + } else { + auto converted = 
BFloat16::FromFloat(value); + std::memcpy(dst, &converted, sizeof(converted)); + } + } + } + + int64_t GetK(std::optional k, Tensor::Size row) const { + if (!k.has_value()) return static_cast(vocab_size_); + + const auto offset = k->size(0) == 1 ? 0 : row; + int64_t value = 0; + if (k->dtype() == DataType::kInt32) { + value = static_cast(k->data())[offset]; + } else { + value = static_cast(k->data())[offset]; + } + + if (value <= 0) return static_cast(vocab_size_); + return std::min(value, static_cast(vocab_size_)); + } + + double GetP(std::optional p, Tensor::Size row) const { + if (!p.has_value()) return 1.0; + + const auto offset = p->size(0) == 1 ? 0 : row; + double value = 1.0; + switch (p->dtype()) { + case DataType::kFloat16: + value = static_cast(p->data())[offset].ToFloat(); + break; + case DataType::kBFloat16: + value = static_cast(p->data())[offset].ToFloat(); + break; + case DataType::kFloat32: + value = static_cast(p->data())[offset]; + break; + case DataType::kFloat64: + value = static_cast(p->data())[offset]; + break; + default: + assert(false && "`TopKTopPSampler` has unsupported `p` dtype"); + } + + if (value <= 0.0 || value > 1.0) return 1.0; + return value; + } + + mutable ascend::AclTensorCache logits_cache_; + + mutable ascend::AclTensorCache top_k_cache_; + + mutable ascend::AclTensorCache top_p_cache_; + + mutable ascend::AclTensorCache selected_idx_cache_; + + mutable ascend::AclTensorCache selected_logits_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable std::vector top_k_host_; + + mutable std::vector top_p_host_; + + mutable aclOpExecutor* sample_exec_ = nullptr; + + mutable uint64_t sample_ws_size_ = 0; + + mutable aclOpExecutor* cast_exec_ = nullptr; + + mutable uint64_t cast_ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_ASCEND_TOP_K_TOP_P_SAMPLER_KERNEL_H_ diff --git a/tests/test_top_k_top_p_sampler.py b/tests/test_top_k_top_p_sampler.py new file mode 100644 index 00000000..8d898e0e --- 
# ============================================================================
# Reconstructed from a whitespace-mangled `git diff`.
# File: tests/test_top_k_top_p_sampler.py (new file, index 8d898e0e)
# ============================================================================
import infini.ops
import pytest
import torch

from tests.utils import Payload, get_stream


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize("shape", ((1, 8), (3, 16)))
@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
def test_top_k_top_p_sampler(
    shape,
    dtype,
    device,
    implementation_index,
):
    """Greedy sampling (``top_k == 1``) must match a plain argmax.

    One logit per row is boosted to +10 (the rest are -10), so the expected
    token id for row ``i`` is ``i % vocab_size``.
    """
    batch_size, vocab_size = shape
    logits = torch.full(shape, -10.0, dtype=dtype, device=device)

    for i in range(batch_size):
        logits[i, i % vocab_size] = 10.0

    # `k`/`p` live on the CPU: the Ascend path reads them host-side.
    k = torch.ones((batch_size,), dtype=torch.int64, device="cpu")
    p = torch.ones((batch_size,), dtype=torch.float32, device="cpu")
    out = torch.empty((batch_size,), dtype=torch.int32, device=device)

    return Payload(
        _top_k_top_p_sampler,
        _torch_argmax,
        (logits, k, p, out),
        {"implementation_index": implementation_index},
    )


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16))
def test_top_k_top_p_sampler_optional_p(
    dtype,
    device,
    implementation_index,
):
    """Omitting ``p`` (None) must behave like ``p == 1.0`` (no nucleus filter).

    ``k`` is shaped ``[1]`` to also exercise broadcasting; the boosted logit
    for row ``i`` sits at ``(i + 1) % vocab_size``.
    """
    shape = (3, 16)
    batch_size, vocab_size = shape
    logits = torch.full(shape, -10.0, dtype=dtype, device=device)

    for i in range(batch_size):
        logits[i, (i + 1) % vocab_size] = 10.0

    k = torch.ones((1,), dtype=torch.int64, device="cpu")
    out = torch.empty((batch_size,), dtype=torch.int32, device=device)

    return Payload(
        _top_k_top_p_sampler,
        _torch_argmax,
        (logits, k, None, out),
        {"implementation_index": implementation_index},
    )


def _top_k_top_p_sampler(logits, k, p, out, *, implementation_index):
    """Act: run the operator under test and return its `out` tensor."""
    infini.ops.top_k_top_p_sampler(
        logits,
        k,
        p,
        out,
        stream=get_stream(logits.device),
        implementation_index=implementation_index,
    )

    return out


def _torch_argmax(logits, k, p, out, *, implementation_index):
    """Reference: greedy (k == 1, p == 1.0) sampling is a row-wise argmax."""
    del k, p, implementation_index

    out.copy_(torch.argmax(logits, dim=-1).to(torch.int32))

    return out