From cd01e91090afef9f3907621656428c61bdb69d66 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 30 Jun 2026 15:09:32 +0800 Subject: [PATCH] feat(ascend): add `mha_fwd_kvcache` operator --- src/base/mha_fwd_kvcache.h | 215 ++++++++++++++++++ src/native/ascend/ops/fia_common_.h | 157 +++++++++++++ src/native/ascend/ops/graph_cleanup_.h | 60 +++++ .../ascend/ops/mha_fwd_kvcache/kernel.h | 189 +++++++++++++++ tests/test_mha_fwd_kvcache.py | 189 +++++++++++++++ 5 files changed, 810 insertions(+) create mode 100644 src/base/mha_fwd_kvcache.h create mode 100644 src/native/ascend/ops/fia_common_.h create mode 100644 src/native/ascend/ops/graph_cleanup_.h create mode 100644 src/native/ascend/ops/mha_fwd_kvcache/kernel.h create mode 100644 tests/test_mha_fwd_kvcache.py diff --git a/src/base/mha_fwd_kvcache.h b/src/base/mha_fwd_kvcache.h new file mode 100644 index 000000000..eb1112574 --- /dev/null +++ b/src/base/mha_fwd_kvcache.h @@ -0,0 +1,215 @@ +#ifndef INFINI_OPS_BASE_MHA_FWD_KVCACHE_H_ +#define INFINI_OPS_BASE_MHA_FWD_KVCACHE_H_ + +#include +#include +#include + +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// FlashAttention-compatible forward attention over an existing KV cache. +// +// Layout follows `flash_attn` `fwd_kvcache`: +// `q`: `[batch, seqlen_q, num_heads, head_size]`. +// `kcache` / `vcache`: `[batch_cache, seqlen_k, num_kv_heads, head_size]` +// or paged `[num_blocks, page_block_size, num_kv_heads, head_size]` when +// `block_table` is supplied. +// +// InfiniOps is an in-place operator API, so `out` must be supplied even +// though FlashAttention can allocate it when omitted. +// This signature intentionally keeps `out` near the FlashAttention argument +// position for compatibility with existing callers; this is an exception to +// the normal InfiniOps output-last convention. +class MhaFwdKvcache : public Operator { + public: + MhaFwdKvcache( + const Tensor q, const Tensor kcache, const Tensor vcache, + std::optional k, std::optional v, + std::optional seqlens_k, std::optional rotary_cos, + std::optional rotary_sin, std::optional cache_batch_idx, + std::optional leftpad_k, std::optional block_table, + std::optional alibi_slopes, Tensor out, float softmax_scale, + bool is_causal, int64_t window_size_left, int64_t window_size_right, + float softcap, bool is_rotary_interleaved, int64_t num_splits = 0) + : batch_size_{q.size(0)}, + seqlen_q_{q.size(1)}, + num_heads_{q.size(2)}, + head_size_{q.size(3)}, + num_kv_heads_{kcache.size(2)}, + page_block_size_{block_table.has_value() ? kcache.size(1) : 0}, + softmax_scale_{softmax_scale}, + is_causal_{is_causal}, + window_size_left_{window_size_left}, + window_size_right_{window_size_right}, + softcap_{softcap}, + is_rotary_interleaved_{is_rotary_interleaved}, + num_splits_{num_splits}, + q_dtype_{q.dtype()}, + has_k_{k.has_value()}, + has_v_{v.has_value()}, + has_seqlens_k_{seqlens_k.has_value()}, + has_rotary_cos_{rotary_cos.has_value()}, + has_rotary_sin_{rotary_sin.has_value()}, + has_cache_batch_idx_{cache_batch_idx.has_value()}, + has_leftpad_k_{leftpad_k.has_value()}, + has_block_table_{block_table.has_value()}, + has_alibi_slopes_{alibi_slopes.has_value()} { + assert(q.ndim() == 4 && + "`MhaFwdKvcache` requires `q` to be `[batch, seq, heads, dim]`"); + assert(kcache.ndim() == 4 && vcache.ndim() == 4 && + "`MhaFwdKvcache` requires `kcache` / `vcache` to be 4D."); + assert(kcache.shape() == vcache.shape() && + "`MhaFwdKvcache` requires `kcache` and `vcache` same shape"); + assert(kcache.dtype() == q.dtype() && vcache.dtype() == q.dtype() && + "`MhaFwdKvcache` requires `q`, `kcache`, and `vcache` same dtype"); + assert(out.dtype() == q.dtype() && + "`MhaFwdKvcache` requires `out` to have same dtype as `q`."); + assert(kcache.size(3) == head_size_ && + "`MhaFwdKvcache` requires cache head dim to match `q`."); + assert(q.stride(-1) == 1 && kcache.stride(-1) == 1 && + vcache.stride(-1) == 1 && + "`MhaFwdKvcache` requires contiguous last dimension"); + assert(num_heads_ % num_kv_heads_ == 0 && + "`MhaFwdKvcache` requires `num_heads` divisible by `num_kv_heads`"); + assert(head_size_ <= 256 && + "`MhaFwdKvcache` supports `head_size` up to 256"); + assert(out.shape() == q.shape() && + "`MhaFwdKvcache` requires `out` to match `q` shape"); + assert(out.stride(-1) == 1 && + "`MhaFwdKvcache` requires `out` contiguous last dimension"); + assert(std::isfinite(softmax_scale_) && + "`MhaFwdKvcache` requires finite `softmax_scale`."); + + if (k.has_value() || v.has_value()) { + assert(k.has_value() && v.has_value() && + "`MhaFwdKvcache` requires `k` and `v` together."); + if (k.has_value() && v.has_value()) { + assert(k->ndim() == 4 && v->ndim() == 4 && + "`MhaFwdKvcache` requires appended `k` / `v` to be 4D."); + assert(k->shape() == v->shape() && + "`MhaFwdKvcache` requires appended `k` and `v` same shape."); + assert(k->dtype() == q.dtype() && v->dtype() == q.dtype() && + "`MhaFwdKvcache` requires appended `k` / `v` same dtype as " + "`q`."); + assert(k->size(0) == batch_size_ && k->size(2) == num_kv_heads_ && + k->size(3) == head_size_ && + "`MhaFwdKvcache` requires appended `k` / `v` to match cache " + "batch, heads, and dim."); + assert(k->stride(-1) == 1 && v->stride(-1) == 1 && + "`MhaFwdKvcache` requires appended `k` / `v` contiguous last " + "dimension."); + } + } + + if (seqlens_k.has_value()) { + assert(seqlens_k->ndim() == 1 && + "`MhaFwdKvcache` requires `seqlens_k` to be 1D."); + assert((seqlens_k->dtype() == DataType::kInt32 || + seqlens_k->dtype() == DataType::kInt64) && + "`MhaFwdKvcache` requires `seqlens_k` to be `int32` or `int64`."); + } + + if (block_table.has_value()) { + assert(block_table->ndim() == 2 && + "`MhaFwdKvcache` requires `block_table` to be 2D."); + assert(block_table->dtype() == DataType::kInt32 && + "`MhaFwdKvcache` requires `block_table` to be `int32`."); + } + + if (cache_batch_idx.has_value()) { + assert(cache_batch_idx->ndim() == 1 && + "`MhaFwdKvcache` requires `cache_batch_idx` to be 1D."); + assert(cache_batch_idx->dtype() == DataType::kInt32 && + "`MhaFwdKvcache` requires `cache_batch_idx` to be `int32`."); + } + + if (leftpad_k.has_value()) { + assert(leftpad_k->ndim() == 1 && + "`MhaFwdKvcache` requires `leftpad_k` to be 1D."); + assert(leftpad_k->dtype() == DataType::kInt32 && + "`MhaFwdKvcache` requires `leftpad_k` to be `int32`."); + } + + if (rotary_cos.has_value() || rotary_sin.has_value()) { + assert(rotary_cos.has_value() && rotary_sin.has_value() && + "`MhaFwdKvcache` requires `rotary_cos` and `rotary_sin` " + "together."); + if (rotary_cos.has_value() && rotary_sin.has_value()) { + assert(rotary_cos->shape() == rotary_sin->shape() && + "`MhaFwdKvcache` requires rotary tensors to have same shape."); + } + } + + if (alibi_slopes.has_value()) { + assert((alibi_slopes->ndim() == 1 || alibi_slopes->ndim() == 2) && + "`MhaFwdKvcache` requires `alibi_slopes` to be 1D or 2D."); + assert(alibi_slopes->dtype() == DataType::kFloat32 && + "`MhaFwdKvcache` requires `alibi_slopes` to be `float32`."); + } + } + + virtual void operator()( + const Tensor q, const Tensor kcache, const Tensor vcache, + std::optional k, std::optional v, + std::optional seqlens_k, std::optional rotary_cos, + std::optional rotary_sin, std::optional cache_batch_idx, + std::optional leftpad_k, std::optional block_table, + std::optional alibi_slopes, Tensor out, float softmax_scale, + bool is_causal, int64_t window_size_left, int64_t window_size_right, + float softcap, bool is_rotary_interleaved, + int64_t num_splits = 0) const = 0; + + protected: + Tensor::Size batch_size_{0}; + + Tensor::Size seqlen_q_{0}; + + Tensor::Size num_heads_{0}; + + Tensor::Size head_size_{0}; + + Tensor::Size num_kv_heads_{0}; + + Tensor::Size page_block_size_{0}; + + float softmax_scale_{0.0f}; + + bool is_causal_{false}; + + int64_t window_size_left_{-1}; + + int64_t window_size_right_{-1}; + + float softcap_{0.0f}; + + bool is_rotary_interleaved_{false}; + + int64_t num_splits_{0}; + + const DataType q_dtype_; + + bool has_k_{false}; + + bool has_v_{false}; + + bool has_seqlens_k_{false}; + + bool has_rotary_cos_{false}; + + bool has_rotary_sin_{false}; + + bool has_cache_batch_idx_{false}; + + bool has_leftpad_k_{false}; + + bool has_block_table_{false}; + + bool has_alibi_slopes_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/ascend/ops/fia_common_.h b/src/native/ascend/ops/fia_common_.h new file mode 100644 index 000000000..83ea2a3b4 --- /dev/null +++ b/src/native/ascend/ops/fia_common_.h @@ -0,0 +1,157 @@ +#ifndef INFINI_OPS_ASCEND_FIA_COMMON__H_ +#define INFINI_OPS_ASCEND_FIA_COMMON__H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "data_type.h" +#include "native/ascend/common.h" +#include "tensor.h" + +namespace infini::ops::ascend::fia { + +inline void AssertSupportedDtype(const DataType dtype, const char* op_name) { + (void)op_name; + assert((dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) && + "`FIA`: only `float16` and `bfloat16` are supported"); +} + +inline std::vector ReadIntTensor(const Tensor& tensor, + aclrtStream stream) { + assert(tensor.ndim() == 1 && "`FIA`: sequence tensor must be 1D"); + + const auto n = tensor.numel(); + std::vector result(n); + + if (tensor.dtype() == DataType::kInt32) { + std::vector tmp(n); + const int32_t* src = nullptr; + + if (tensor.device().type() == Device::Type::kCpu) { + src = static_cast(tensor.data()); + } else { + auto ret = aclrtMemcpyAsync(tmp.data(), n * sizeof(int32_t), + tensor.data(), n * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed"); + ret = aclrtSynchronizeStream(stream); + assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed"); + src = tmp.data(); + } + + for (std::size_t i = 0; i < n; ++i) { + result[i] = static_cast(src[i]); + } + + return result; + } + + if (tensor.dtype() == DataType::kInt64) { + if (tensor.device().type() == Device::Type::kCpu) { + const auto* src = static_cast(tensor.data()); + result.assign(src, src + n); + } else { + auto ret = aclrtMemcpyAsync(result.data(), n * sizeof(int64_t), + tensor.data(), n * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed"); + ret = aclrtSynchronizeStream(stream); + assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed"); + } + + return result; + } + + assert(false && "`FIA`: sequence tensor must be `int32` or `int64`"); + + return result; +} + +inline aclIntArray* CreateCumSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto values = ReadIntTensor(cu_seqlens, stream); + assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch"); + + return aclCreateIntArray(values.data() + 1, + static_cast(values.size() - 1)); +} + +inline aclIntArray* CreateDiffSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto values = ReadIntTensor(cu_seqlens, stream); + assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch"); + + std::vector lengths(values.size() - 1); + for (std::size_t i = 0; i < lengths.size(); ++i) { + lengths[i] = values[i + 1] - values[i]; + } + + return aclCreateIntArray(lengths.data(), + static_cast(lengths.size())); +} + +inline aclIntArray* CreateSeqLengths(const Tensor& seqlens, + aclrtStream stream) { + auto values = ReadIntTensor(seqlens, stream); + + return aclCreateIntArray(values.data(), static_cast(values.size())); +} + +inline aclTensor* MakeCausalMask(void** mask_buf) { + constexpr int64_t kMaskDim = 2048; + const int64_t mask_elems = kMaskDim * kMaskDim; + const auto mask_bytes = static_cast(mask_elems); + + auto ret = aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`FIA`: causal mask allocation failed"); + + std::vector host_mask(mask_elems); + for (int64_t row = 0; row < kMaskDim; ++row) { + for (int64_t col = 0; col < kMaskDim; ++col) { + host_mask[row * kMaskDim + col] = (col > row) ? 1 : 0; + } + } + + ret = aclrtMemcpy(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + assert(ret == ACL_SUCCESS && "`FIA`: causal mask upload failed"); + + std::vector shape = {kMaskDim, kMaskDim}; + std::vector strides = {kMaskDim, 1}; + std::vector storage = {mask_elems}; + + return aclCreateTensor(shape.data(), static_cast(shape.size()), + ACL_UINT8, strides.data(), 0, ACL_FORMAT_ND, + storage.data(), static_cast(storage.size()), + *mask_buf); +} + +inline void ResolveSparseMode(bool is_causal, int64_t window_left, + int64_t window_right, int64_t& sparse_mode, + int64_t& pre_tokens, int64_t& next_tokens) { + sparse_mode = 0; + pre_tokens = 2147483647; + next_tokens = 2147483647; + + if (is_causal) { + if (window_left >= 0) { + sparse_mode = 4; + pre_tokens = window_left; + } else { + sparse_mode = 3; + } + next_tokens = 0; + + return; + } + + if (window_left >= 0) pre_tokens = window_left; + if (window_right >= 0) next_tokens = window_right; +} + +} // namespace infini::ops::ascend::fia + +#endif diff --git a/src/native/ascend/ops/graph_cleanup_.h b/src/native/ascend/ops/graph_cleanup_.h new file mode 100644 index 000000000..8dfb61319 --- /dev/null +++ b/src/native/ascend/ops/graph_cleanup_.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_ASCEND_GRAPH_CLEANUP__H_ +#define INFINI_OPS_ASCEND_GRAPH_CLEANUP__H_ + +#include +#include +#include + +namespace infini::ops::ascend { + +class DeferredAclCleanupScope; + +namespace detail { + +inline thread_local DeferredAclCleanupScope* active_acl_cleanup_scope = nullptr; + +} // namespace detail + +class DeferredAclCleanupScope { + public: + DeferredAclCleanupScope() : previous_(detail::active_acl_cleanup_scope) { + detail::active_acl_cleanup_scope = this; + } + + ~DeferredAclCleanupScope() { + detail::active_acl_cleanup_scope = previous_; + + for (auto& cleanup : callbacks_) { + cleanup(); + } + } + + DeferredAclCleanupScope(const DeferredAclCleanupScope&) = delete; + + DeferredAclCleanupScope& operator=(const DeferredAclCleanupScope&) = delete; + + void Defer(std::function cleanup) { + callbacks_.push_back(std::move(cleanup)); + } + + std::vector> Release() { return std::move(callbacks_); } + + private: + DeferredAclCleanupScope* previous_; + + std::vector> callbacks_; +}; + +inline void DeferOrRunAclCleanup(std::function cleanup) { + if (detail::active_acl_cleanup_scope) { + detail::active_acl_cleanup_scope->Defer(std::move(cleanup)); + + return; + } + + cleanup(); +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/native/ascend/ops/mha_fwd_kvcache/kernel.h b/src/native/ascend/ops/mha_fwd_kvcache/kernel.h new file mode 100644 index 000000000..665ec3a7f --- /dev/null +++ b/src/native/ascend/ops/mha_fwd_kvcache/kernel.h @@ -0,0 +1,189 @@ +#ifndef INFINI_OPS_ASCEND_MHA_FWD_KVCACHE_KERNEL_H_ +#define INFINI_OPS_ASCEND_MHA_FWD_KVCACHE_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" +#include "base/mha_fwd_kvcache.h" +#include "native/ascend/common.h" +#include "native/ascend/ops/fia_common_.h" +#include "native/ascend/ops/graph_cleanup_.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public MhaFwdKvcache { + public: + Operator(const Tensor q, const Tensor kcache, const Tensor vcache, + std::optional k, std::optional v, + std::optional seqlens_k, std::optional rotary_cos, + std::optional rotary_sin, + std::optional cache_batch_idx, + std::optional leftpad_k, std::optional block_table, + std::optional alibi_slopes, Tensor out, float softmax_scale, + bool is_causal, int64_t window_size_left, int64_t window_size_right, + float softcap, bool is_rotary_interleaved, int64_t num_splits = 0) + : MhaFwdKvcache(q, kcache, vcache, k, v, seqlens_k, rotary_cos, + rotary_sin, cache_batch_idx, leftpad_k, block_table, + alibi_slopes, out, softmax_scale, is_causal, + window_size_left, window_size_right, softcap, + is_rotary_interleaved, num_splits) { + ascend::fia::AssertSupportedDtype(q.dtype(), "`MhaFwdKvcache`"); + + assert(seqlen_q_ == 1 && + "`MhaFwdKvcache`: this Ascend path supports decode " + "`seqlen_q == 1` only"); + assert(has_block_table_ && + "`MhaFwdKvcache`: this Ascend path requires paged KV " + "`block_table`"); + assert(has_seqlens_k_ && + "`MhaFwdKvcache`: this Ascend path requires `seqlens_k`"); + assert(page_block_size_ > 0 && page_block_size_ <= 512 && + page_block_size_ % 16 == 0 && + "`MhaFwdKvcache`: paged KV block size must be 16-aligned and " + "not exceed 512 for `float16` and `bfloat16`"); + + auto acl_dt = ascend::ToAclDtype(q.dtype()); + const auto B = static_cast(q.size(0)); + const auto N = static_cast(q.size(2)); + const auto D = static_cast(q.size(3)); + q_cache_ = ascend::AclTensorCache({B, N, 1, D}, acl_dt, + const_cast(q.data())); + out_cache_ = ascend::AclTensorCache({B, N, 1, D}, acl_dt, out.data()); + block_table_cache_ = ascend::AclTensorCache(block_table.value()); + + const int64_t num_blocks = kcache.size(0); + const int64_t block_size = kcache.size(1); + const int64_t num_kv_heads = kcache.size(2); + const int64_t head_dim = kcache.size(3); + // FIA decode expects paged KV descriptors as `[num_blocks, num_kv_heads, + // block_size, head_dim]`. The physical cache is stored as + // `[num_blocks, block_size, num_kv_heads, head_dim]`, so describe it as a + // strided view instead of moving data. + kv_shape_ = {num_blocks, num_kv_heads, block_size, head_dim}; + kv_strides_ = {block_size * num_kv_heads * head_dim, head_dim, + num_kv_heads * head_dim, 1}; + kv_storage_shape_ = {num_blocks * block_size * num_kv_heads * head_dim}; + kv_acl_dtype_ = acl_dt; + } + + void operator()( + const Tensor q, const Tensor kcache, const Tensor vcache, + std::optional k, std::optional v, + std::optional seqlens_k, std::optional rotary_cos, + std::optional rotary_sin, std::optional cache_batch_idx, + std::optional leftpad_k, std::optional block_table, + std::optional alibi_slopes, Tensor out, float softmax_scale, + bool is_causal, int64_t window_size_left, int64_t window_size_right, + float softcap, bool is_rotary_interleaved, + int64_t num_splits) const override { + (void)softmax_scale; + (void)is_causal; + (void)window_size_left; + (void)window_size_right; + (void)is_rotary_interleaved; + + assert(!k.has_value() && !v.has_value() && + "`MhaFwdKvcache`: appending new `k` / `v` into cache is not " + "supported by this Ascend path"); + assert(seqlens_k.has_value() && "`MhaFwdKvcache`: `seqlens_k` is required"); + assert(!rotary_cos.has_value() && !rotary_sin.has_value() && + "`MhaFwdKvcache`: rotary-on-append is not supported by this " + "Ascend path"); + assert(!cache_batch_idx.has_value() && + "`MhaFwdKvcache`: `cache_batch_idx` is not supported by paged KV"); + assert(!leftpad_k.has_value() && + "`MhaFwdKvcache`: `leftpad_k` is not supported by this Ascend path"); + assert(block_table.has_value() && + "`MhaFwdKvcache`: paged KV `block_table` is required"); + assert(!alibi_slopes.has_value() && + "`MhaFwdKvcache`: `alibi_slopes` is not supported by this Ascend " + "path"); + assert(softcap <= 0.0f && + "`MhaFwdKvcache`: `softcap` is not supported by this Ascend path"); + assert(num_splits == 0 && + "`MhaFwdKvcache`: split-KV is not supported by this Ascend path"); + + auto stream = static_cast(stream_); + auto seq_k = ascend::fia::CreateSeqLengths(seqlens_k.value(), stream); + + auto t_q = q_cache_.get(const_cast(q.data())); + auto t_out = out_cache_.get(out.data()); + auto t_block_table = + block_table_cache_.get(const_cast(block_table->data())); + + auto t_k = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dtype_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(kcache.data())); + auto t_v = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dtype_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(vcache.data())); + + const aclTensor* key_arr[] = {t_k}; + const aclTensor* value_arr[] = {t_v}; + auto key_list = aclCreateTensorList(key_arr, 1); + auto value_list = aclCreateTensorList(value_arr, 1); + + uint64_t ws_size = 0; + aclOpExecutor* executor = nullptr; + auto ret = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_q, key_list, value_list, + nullptr, // pseShift. + nullptr, // attenMask. + nullptr, seq_k, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, t_block_table, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, static_cast(num_heads_), softmax_scale_, 2147483647, + 2147483647, const_cast("BNSD"), + static_cast(num_kv_heads_), + 0, // sparseMode. + 0, // innerPrecise. + static_cast(page_block_size_), + 0, // antiquantMode. + false, // softmaxLseFlag. + 0, 0, 0, t_out, nullptr, &ws_size, &executor); + assert(ret == ACL_SUCCESS && + "`aclnnFusedInferAttentionScoreV4GetWorkspaceSize` failed"); + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); + ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_size, executor, stream); + assert(ret == ACL_SUCCESS && "`aclnnFusedInferAttentionScoreV4` failed"); + + // Keep per-call descriptors alive until the RI graph task update using + // them has completed. + ascend::DeferOrRunAclCleanup([key_list, value_list, seq_k]() { + aclDestroyTensorList(key_list); + aclDestroyTensorList(value_list); + aclDestroyIntArray(seq_k); + }); + } + + private: + mutable ascend::AclTensorCache q_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache block_table_cache_; + + std::vector kv_shape_; + + std::vector kv_strides_; + + std::vector kv_storage_shape_; + + aclDataType kv_acl_dtype_{ACL_DT_UNDEFINED}; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_mha_fwd_kvcache.py b/tests/test_mha_fwd_kvcache.py new file mode 100644 index 000000000..6958b16fb --- /dev/null +++ b/tests/test_mha_fwd_kvcache.py @@ -0,0 +1,189 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "batch_size, kv_len, num_heads, num_kv_heads, head_size", + ( + (1, 30, 7, 1, 128), + (2, 17, 8, 2, 64), + (3, 23, 16, 4, 128), + ), +) +@pytest.mark.parametrize("seqlens_on_cpu", (False, True)) +@pytest.mark.parametrize("block_size", (64, 128, 256)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_mha_fwd_kvcache_paged_decode( + batch_size, + kv_len, + num_heads, + num_kv_heads, + head_size, + seqlens_on_cpu, + block_size, + implementation_index, + dtype, + device, + rtol, + atol, +): + blocks_per_seq = (kv_len + block_size - 1) // block_size + num_blocks = batch_size * blocks_per_seq + scale = 1.0 / head_size**0.5 + + q = randn_strided( + (batch_size, 1, num_heads, head_size), None, dtype=dtype, device=device + ) + kcache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + vcache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + out = torch.empty_like(q) + block_table = torch.empty( + (batch_size, blocks_per_seq), dtype=torch.int32, device=device + ) + + for batch in range(batch_size): + for block in range(blocks_per_seq): + block_table[batch, block] = batch * blocks_per_seq + block + + seqlens_device = "cpu" if seqlens_on_cpu else device + seqlens_k = torch.full( + (batch_size,), kv_len, dtype=torch.int32, device=seqlens_device + ) + + return Payload( + lambda *args, **kwargs: _mha_fwd_kvcache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_mha_fwd_kvcache, + (q, kcache, vcache, seqlens_k, block_table, out), + { + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "block_size": block_size, + "scale": scale, + }, + rtol=rtol, + atol=atol, + ) + + +def _mha_fwd_kvcache( + q, + kcache, + vcache, + seqlens_k, + block_table, + out, + *, + num_heads, + num_kv_heads, + block_size, + scale, + implementation_index=0, +): + del num_heads, num_kv_heads, block_size + + infini.ops.mha_fwd_kvcache( + q, + kcache, + vcache, + None, + None, + seqlens_k, + None, + None, + None, + None, + block_table, + None, + out, + scale, + True, + -1, + 0, + 0.0, + False, + 0, + implementation_index=implementation_index, + stream=get_stream(q.device), + ) + + return out + + +def _ref_mha_fwd_kvcache( + q, + kcache, + vcache, + seqlens_k, + block_table, + out, + *, + num_heads, + num_kv_heads, + block_size, + scale, +): + del out + + q_cpu = q.cpu().float() + kcache_cpu = kcache.cpu().float() + vcache_cpu = vcache.cpu().float() + seqlens_cpu = seqlens_k.cpu() + block_table_cpu = block_table.cpu() + outputs = [] + + for batch in range(q.shape[0]): + kv_len = int(seqlens_cpu[batch].item()) + remaining = kv_len + k_pages = [] + v_pages = [] + + for block in block_table_cpu[batch]: + if remaining <= 0: + break + + take = min(remaining, block_size) + block_id = int(block.item()) + k_pages.append(kcache_cpu[block_id, :take].transpose(0, 1)) + v_pages.append(vcache_cpu[block_id, :take].transpose(0, 1)) + remaining -= take + + k_i = torch.cat(k_pages, dim=1) + v_i = torch.cat(v_pages, dim=1) + + if num_kv_heads < num_heads: + repeat = num_heads // num_kv_heads + k_i = k_i.repeat_interleave(repeat, dim=0) + v_i = v_i.repeat_interleave(repeat, dim=0) + + ref = torch.nn.functional.scaled_dot_product_attention( + q_cpu[batch].transpose(0, 1).unsqueeze(0), + k_i.unsqueeze(0), + v_i.unsqueeze(0), + scale=scale, + is_causal=False, + ) + outputs.append(ref.squeeze(0).transpose(0, 1).unsqueeze(0).to(q.dtype)) + + return torch.cat(outputs, dim=0).to(q.device)