From f47af18382ad15e0a592f19d140b54923be0a6e3 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 30 Jun 2026 15:09:33 +0800 Subject: [PATCH] feat(ascend): add `mha_varlen_fwd` operator --- src/base/mha_varlen_fwd.h | 199 ++++++++++++ src/native/ascend/ops/fia_common_.h | 157 +++++++++ src/native/ascend/ops/graph_cleanup_.h | 60 ++++ src/native/ascend/ops/mha_varlen_fwd/kernel.h | 219 +++++++++++++ tests/test_mha_varlen_fwd.py | 306 ++++++++++++++++++ 5 files changed, 941 insertions(+) create mode 100644 src/base/mha_varlen_fwd.h create mode 100644 src/native/ascend/ops/fia_common_.h create mode 100644 src/native/ascend/ops/graph_cleanup_.h create mode 100644 src/native/ascend/ops/mha_varlen_fwd/kernel.h create mode 100644 tests/test_mha_varlen_fwd.py diff --git a/src/base/mha_varlen_fwd.h b/src/base/mha_varlen_fwd.h new file mode 100644 index 000000000..bd04f3c25 --- /dev/null +++ b/src/base/mha_varlen_fwd.h @@ -0,0 +1,199 @@ +#ifndef INFINI_OPS_BASE_MHA_VARLEN_FWD_H_ +#define INFINI_OPS_BASE_MHA_VARLEN_FWD_H_ + +#include +#include +#include + +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// FlashAttention-compatible variable-length forward attention. +// +// Layout follows `flash_attn` `varlen_fwd`: +// `q`: `[total_q, num_heads, head_size]`. +// `k` / `v`: `[total_k, num_kv_heads, head_size]`. +// Paged `k` / `v`: `[num_blocks, block_size, num_kv_heads, head_size]` +// when `block_table` is supplied. +// `cu_seqlens_q` / `cu_seqlens_k`: `[batch_size + 1]`. +// +// InfiniOps is an in-place operator API, so `out` must be supplied even +// though FlashAttention can allocate it when omitted. +// This signature intentionally keeps `out` near the FlashAttention argument +// position for compatibility with existing callers; this is an exception to +// the normal InfiniOps output-last convention. +class MhaVarlenFwd : public Operator { + public: + MhaVarlenFwd(const Tensor q, const Tensor k, const Tensor v, Tensor out, + const Tensor cu_seqlens_q, const Tensor cu_seqlens_k, + std::optional seqused_k, std::optional leftpad_k, + std::optional block_table, + std::optional alibi_slopes, int64_t max_seqlen_q, + int64_t max_seqlen_k, float p_dropout, float softmax_scale, + bool zero_tensors, bool is_causal, int64_t window_size_left, + int64_t window_size_right, float softcap, bool return_softmax, + std::optional generator = std::nullopt, + int64_t num_splits = 0) + : batch_size_{cu_seqlens_q.numel() - 1}, + total_q_{q.size(0)}, + num_heads_{q.size(1)}, + head_size_{q.size(2)}, + num_kv_heads_{block_table.has_value() ? k.size(2) : k.size(1)}, + max_seqlen_q_{max_seqlen_q}, + max_seqlen_k_{max_seqlen_k}, + p_dropout_{p_dropout}, + softmax_scale_{softmax_scale}, + zero_tensors_{zero_tensors}, + is_causal_{is_causal}, + window_size_left_{window_size_left}, + window_size_right_{window_size_right}, + softcap_{softcap}, + return_softmax_{return_softmax}, + num_splits_{num_splits}, + q_dtype_{q.dtype()}, + has_seqused_k_{seqused_k.has_value()}, + has_leftpad_k_{leftpad_k.has_value()}, + has_block_table_{block_table.has_value()}, + has_alibi_slopes_{alibi_slopes.has_value()}, + has_generator_{generator.has_value()} { + assert(q.ndim() == 3 && + "`MhaVarlenFwd` requires `q` to be `[total_q, heads, dim]`"); + if (has_block_table_) { + assert(k.ndim() == 4 && v.ndim() == 4 && + "`MhaVarlenFwd` with `block_table` requires paged `k` / `v` to " + "be `[num_blocks, block_size, heads, dim]`."); + assert(block_table->ndim() == 2 && + "`MhaVarlenFwd` requires `block_table` to be 2D."); + assert(block_table->dtype() == DataType::kInt32 && + "`MhaVarlenFwd` requires `block_table` to be `int32`."); + assert(k.shape() == v.shape() && + "`MhaVarlenFwd` requires paged `k` and `v` to have same shape."); + assert(k.size(3) == head_size_ && + "`MhaVarlenFwd` requires paged `k` / `v` last dim to match `q`."); + } else { + assert(k.ndim() == 3 && v.ndim() == 3 && + "`MhaVarlenFwd` requires `k` / `v` to be " + "`[total_k, heads, dim]`."); + assert(k.shape() == v.shape() && + "`MhaVarlenFwd` requires `k` and `v` to have same shape."); + assert(k.size(2) == head_size_ && + "`MhaVarlenFwd` requires `k` / `v` last dim to match `q`."); + } + assert(k.dtype() == q.dtype() && v.dtype() == q.dtype() && + "`MhaVarlenFwd` requires `q`, `k`, and `v` to have same dtype"); + assert(out.dtype() == q.dtype() && + "`MhaVarlenFwd` requires `out` to have same dtype as `q`."); + assert(k.stride(-1) == 1 && v.stride(-1) == 1 && q.stride(-1) == 1 && + "`MhaVarlenFwd` requires contiguous last dimension"); + assert(num_heads_ % num_kv_heads_ == 0 && + "`MhaVarlenFwd` requires `num_heads` divisible by `num_kv_heads`"); + assert(head_size_ <= 256 && + "`MhaVarlenFwd` supports `head_size` up to 256"); + assert(head_size_ % 8 == 0 && + "`MhaVarlenFwd` requires `head_size` to be a multiple of 8 in " + "InfiniOps v1."); + assert(cu_seqlens_q.ndim() == 1 && cu_seqlens_k.ndim() == 1 && + "`MhaVarlenFwd` requires 1D `cu_seqlens_q` / `cu_seqlens_k`"); + assert(cu_seqlens_q.dtype() == DataType::kInt32 && + cu_seqlens_k.dtype() == DataType::kInt32 && + "`MhaVarlenFwd` requires `cu_seqlens_q` and `cu_seqlens_k` to be " + "`int32`."); + assert(cu_seqlens_q.numel() > 1 && + "`MhaVarlenFwd` requires non-empty `cu_seqlens`."); + assert(cu_seqlens_q.numel() == cu_seqlens_k.numel() && + "`MhaVarlenFwd` requires matching `cu_seqlens` lengths"); + assert(out.shape() == q.shape() && + "`MhaVarlenFwd` requires `out` to match `q` shape"); + assert(out.stride(-1) == 1 && + "`MhaVarlenFwd` requires `out` contiguous last dimension"); + assert(p_dropout >= 0.0f && p_dropout <= 1.0f && + "`MhaVarlenFwd` requires `p_dropout` in `[0, 1]`."); + assert(std::isfinite(softmax_scale_) && + "`MhaVarlenFwd` requires finite `softmax_scale`."); + assert(max_seqlen_q_ >= 0 && max_seqlen_k_ >= 0 && + "`MhaVarlenFwd` requires non-negative max sequence lengths."); + + if (seqused_k.has_value()) { + assert(seqused_k->ndim() == 1 && + "`MhaVarlenFwd` requires `seqused_k` to be 1D."); + assert(seqused_k->dtype() == DataType::kInt32 && + "`MhaVarlenFwd` requires `seqused_k` to be `int32`."); + } + + if (leftpad_k.has_value()) { + assert(leftpad_k->ndim() == 1 && + "`MhaVarlenFwd` requires `leftpad_k` to be 1D."); + assert(leftpad_k->dtype() == DataType::kInt32 && + "`MhaVarlenFwd` requires `leftpad_k` to be `int32`."); + } + + if (alibi_slopes.has_value()) { + assert((alibi_slopes->ndim() == 1 || alibi_slopes->ndim() == 2) && + "`MhaVarlenFwd` requires `alibi_slopes` to be 1D or 2D."); + assert(alibi_slopes->dtype() == DataType::kFloat32 && + "`MhaVarlenFwd` requires `alibi_slopes` to be `float32`."); + } + } + + virtual void operator()( + const Tensor q, const Tensor k, const Tensor v, Tensor out, + const Tensor cu_seqlens_q, const Tensor cu_seqlens_k, + std::optional seqused_k, std::optional leftpad_k, + std::optional block_table, std::optional alibi_slopes, + int64_t max_seqlen_q, int64_t max_seqlen_k, float p_dropout, + float softmax_scale, bool zero_tensors, bool is_causal, + int64_t window_size_left, int64_t window_size_right, float softcap, + bool return_softmax, std::optional generator = std::nullopt, + int64_t num_splits = 0) const = 0; + + protected: + Tensor::Size batch_size_{0}; + + Tensor::Size total_q_{0}; + + Tensor::Size num_heads_{0}; + + Tensor::Size head_size_{0}; + + Tensor::Size num_kv_heads_{0}; + + int64_t max_seqlen_q_{0}; + + int64_t max_seqlen_k_{0}; + + float p_dropout_{0.0f}; + + float softmax_scale_{0.0f}; + + bool zero_tensors_{false}; + + bool is_causal_{false}; + + int64_t window_size_left_{-1}; + + int64_t window_size_right_{-1}; + + float softcap_{0.0f}; + + bool return_softmax_{false}; + + int64_t num_splits_{0}; + + const DataType q_dtype_; + + bool has_seqused_k_{false}; + + bool has_leftpad_k_{false}; + + bool has_block_table_{false}; + + bool has_alibi_slopes_{false}; + + bool has_generator_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/ascend/ops/fia_common_.h b/src/native/ascend/ops/fia_common_.h new file mode 100644 index 000000000..83ea2a3b4 --- /dev/null +++ b/src/native/ascend/ops/fia_common_.h @@ -0,0 +1,157 @@ +#ifndef INFINI_OPS_ASCEND_FIA_COMMON__H_ +#define INFINI_OPS_ASCEND_FIA_COMMON__H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "data_type.h" +#include "native/ascend/common.h" +#include "tensor.h" + +namespace infini::ops::ascend::fia { + +inline void AssertSupportedDtype(const DataType dtype, const char* op_name) { + (void)op_name; + assert((dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) && + "`FIA`: only `float16` and `bfloat16` are supported"); +} + +inline std::vector ReadIntTensor(const Tensor& tensor, + aclrtStream stream) { + assert(tensor.ndim() == 1 && "`FIA`: sequence tensor must be 1D"); + + const auto n = tensor.numel(); + std::vector result(n); + + if (tensor.dtype() == DataType::kInt32) { + std::vector tmp(n); + const int32_t* src = nullptr; + + if (tensor.device().type() == Device::Type::kCpu) { + src = static_cast(tensor.data()); + } else { + auto ret = aclrtMemcpyAsync(tmp.data(), n * sizeof(int32_t), + tensor.data(), n * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed"); + ret = aclrtSynchronizeStream(stream); + assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed"); + src = tmp.data(); + } + + for (std::size_t i = 0; i < n; ++i) { + result[i] = static_cast(src[i]); + } + + return result; + } + + if (tensor.dtype() == DataType::kInt64) { + if (tensor.device().type() == Device::Type::kCpu) { + const auto* src = static_cast(tensor.data()); + result.assign(src, src + n); + } else { + auto ret = aclrtMemcpyAsync(result.data(), n * sizeof(int64_t), + tensor.data(), n * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + assert(ret == ACL_SUCCESS && "`FIA`: D2H copy failed"); + ret = aclrtSynchronizeStream(stream); + assert(ret == ACL_SUCCESS && "`FIA`: stream synchronize failed"); + } + + return result; + } + + assert(false && "`FIA`: sequence tensor must be `int32` or `int64`"); + + return result; +} + +inline aclIntArray* CreateCumSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto values = ReadIntTensor(cu_seqlens, stream); + assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch"); + + return aclCreateIntArray(values.data() + 1, + static_cast(values.size() - 1)); +} + +inline aclIntArray* CreateDiffSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto values = ReadIntTensor(cu_seqlens, stream); + assert(values.size() > 1 && "`FIA`: `cu_seqlens` must contain a batch"); + + std::vector lengths(values.size() - 1); + for (std::size_t i = 0; i < lengths.size(); ++i) { + lengths[i] = values[i + 1] - values[i]; + } + + return aclCreateIntArray(lengths.data(), + static_cast(lengths.size())); +} + +inline aclIntArray* CreateSeqLengths(const Tensor& seqlens, + aclrtStream stream) { + auto values = ReadIntTensor(seqlens, stream); + + return aclCreateIntArray(values.data(), static_cast(values.size())); +} + +inline aclTensor* MakeCausalMask(void** mask_buf) { + constexpr int64_t kMaskDim = 2048; + const int64_t mask_elems = kMaskDim * kMaskDim; + const auto mask_bytes = static_cast(mask_elems); + + auto ret = aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`FIA`: causal mask allocation failed"); + + std::vector host_mask(mask_elems); + for (int64_t row = 0; row < kMaskDim; ++row) { + for (int64_t col = 0; col < kMaskDim; ++col) { + host_mask[row * kMaskDim + col] = (col > row) ? 1 : 0; + } + } + + ret = aclrtMemcpy(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + assert(ret == ACL_SUCCESS && "`FIA`: causal mask upload failed"); + + std::vector shape = {kMaskDim, kMaskDim}; + std::vector strides = {kMaskDim, 1}; + std::vector storage = {mask_elems}; + + return aclCreateTensor(shape.data(), static_cast(shape.size()), + ACL_UINT8, strides.data(), 0, ACL_FORMAT_ND, + storage.data(), static_cast(storage.size()), + *mask_buf); +} + +inline void ResolveSparseMode(bool is_causal, int64_t window_left, + int64_t window_right, int64_t& sparse_mode, + int64_t& pre_tokens, int64_t& next_tokens) { + sparse_mode = 0; + pre_tokens = 2147483647; + next_tokens = 2147483647; + + if (is_causal) { + if (window_left >= 0) { + sparse_mode = 4; + pre_tokens = window_left; + } else { + sparse_mode = 3; + } + next_tokens = 0; + + return; + } + + if (window_left >= 0) pre_tokens = window_left; + if (window_right >= 0) next_tokens = window_right; +} + +} // namespace infini::ops::ascend::fia + +#endif diff --git a/src/native/ascend/ops/graph_cleanup_.h b/src/native/ascend/ops/graph_cleanup_.h new file mode 100644 index 000000000..8dfb61319 --- /dev/null +++ b/src/native/ascend/ops/graph_cleanup_.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_ASCEND_GRAPH_CLEANUP__H_ +#define INFINI_OPS_ASCEND_GRAPH_CLEANUP__H_ + +#include +#include +#include + +namespace infini::ops::ascend { + +class DeferredAclCleanupScope; + +namespace detail { + +inline thread_local DeferredAclCleanupScope* active_acl_cleanup_scope = nullptr; + +} // namespace detail + +class DeferredAclCleanupScope { + public: + DeferredAclCleanupScope() : previous_(detail::active_acl_cleanup_scope) { + detail::active_acl_cleanup_scope = this; + } + + ~DeferredAclCleanupScope() { + detail::active_acl_cleanup_scope = previous_; + + for (auto& cleanup : callbacks_) { + cleanup(); + } + } + + DeferredAclCleanupScope(const DeferredAclCleanupScope&) = delete; + + DeferredAclCleanupScope& operator=(const DeferredAclCleanupScope&) = delete; + + void Defer(std::function cleanup) { + callbacks_.push_back(std::move(cleanup)); + } + + std::vector> Release() { return std::move(callbacks_); } + + private: + DeferredAclCleanupScope* previous_; + + std::vector> callbacks_; +}; + +inline void DeferOrRunAclCleanup(std::function cleanup) { + if (detail::active_acl_cleanup_scope) { + detail::active_acl_cleanup_scope->Defer(std::move(cleanup)); + + return; + } + + cleanup(); +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/native/ascend/ops/mha_varlen_fwd/kernel.h b/src/native/ascend/ops/mha_varlen_fwd/kernel.h new file mode 100644 index 000000000..d6022ec66 --- /dev/null +++ b/src/native/ascend/ops/mha_varlen_fwd/kernel.h @@ -0,0 +1,219 @@ +#ifndef INFINI_OPS_ASCEND_MHA_VARLEN_FWD_KERNEL_H_ +#define INFINI_OPS_ASCEND_MHA_VARLEN_FWD_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" +#include "base/mha_varlen_fwd.h" +#include "native/ascend/common.h" +#include "native/ascend/ops/fia_common_.h" +#include "native/ascend/ops/graph_cleanup_.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public MhaVarlenFwd { + public: + Operator(const Tensor q, const Tensor k, const Tensor v, Tensor out, + const Tensor cu_seqlens_q, const Tensor cu_seqlens_k, + std::optional seqused_k, std::optional leftpad_k, + std::optional block_table, + std::optional alibi_slopes, int64_t max_seqlen_q, + int64_t max_seqlen_k, float p_dropout, float softmax_scale, + bool zero_tensors, bool is_causal, int64_t window_size_left, + int64_t window_size_right, float softcap, bool return_softmax, + std::optional generator = std::nullopt, + int64_t num_splits = 0) + : MhaVarlenFwd(q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, + leftpad_k, block_table, alibi_slopes, max_seqlen_q, + max_seqlen_k, p_dropout, softmax_scale, zero_tensors, + is_causal, window_size_left, window_size_right, softcap, + return_softmax, generator, num_splits), + q_cache_(q), + out_cache_(out) { + ascend::fia::AssertSupportedDtype(q.dtype(), "`MhaVarlenFwd`"); + + if (has_block_table_) { + block_table_cache_ = ascend::AclTensorCache(block_table.value()); + + const int64_t num_blocks = k.size(0); + const int64_t block_size = k.size(1); + const int64_t kv_nd = k.size(2) * k.size(3); + kv_shape_ = {num_blocks, block_size, kv_nd}; + kv_strides_ = {block_size * kv_nd, kv_nd, 1}; + kv_storage_shape_ = {num_blocks * block_size * kv_nd}; + kv_acl_dtype_ = ascend::ToAclDtype(q.dtype()); + page_block_size_ = block_size; + assert(page_block_size_ % 16 == 0 && + "`MhaVarlenFwd`: paged KV `block_size` must be 16-aligned for " + "`float16` and `bfloat16`."); + } + + if (is_causal_) { + assert(max_seqlen_q_ <= 2048 && max_seqlen_k_ <= 2048 && + "`MhaVarlenFwd`: causal FIA mask currently supports " + "`max_seqlen <= 2048`."); + causal_mask_ = ascend::fia::MakeCausalMask(&causal_mask_buf_); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + if (causal_mask_) aclDestroyTensor(causal_mask_); + if (causal_mask_buf_) aclrtFree(causal_mask_buf_); + } + + void operator()(const Tensor q, const Tensor k, const Tensor v, Tensor out, + const Tensor cu_seqlens_q, const Tensor cu_seqlens_k, + std::optional seqused_k, + std::optional leftpad_k, + std::optional block_table, + std::optional alibi_slopes, int64_t max_seqlen_q, + int64_t max_seqlen_k, float p_dropout, float softmax_scale, + bool zero_tensors, bool is_causal, int64_t window_size_left, + int64_t window_size_right, float softcap, bool return_softmax, + std::optional generator, + int64_t num_splits) const override { + (void)max_seqlen_q; + (void)max_seqlen_k; + (void)softmax_scale; + (void)is_causal; + (void)window_size_left; + (void)window_size_right; + + assert(!seqused_k.has_value() && + "`MhaVarlenFwd`: `seqused_k` is not supported by this Ascend path"); + assert(!leftpad_k.has_value() && + "`MhaVarlenFwd`: `leftpad_k` is not supported by this Ascend path"); + assert(block_table.has_value() == has_block_table_ && + "`MhaVarlenFwd`: `block_table` presence changed after descriptor " + "creation"); + assert(!alibi_slopes.has_value() && + "`MhaVarlenFwd`: `alibi_slopes` is not supported by this Ascend " + "path"); + assert(p_dropout == 0.0f && + "`MhaVarlenFwd`: dropout is not supported by this Ascend path"); + assert(!zero_tensors && + "`MhaVarlenFwd`: `zero_tensors` is not supported by this Ascend " + "path"); + assert(softcap <= 0.0f && + "`MhaVarlenFwd`: `softcap` is not supported by this Ascend path"); + assert(!return_softmax && + "`MhaVarlenFwd`: returning softmax is not supported by InfiniOps " + "in-place API"); + assert(!generator.has_value() && + "`MhaVarlenFwd`: `generator` is only meaningful for dropout"); + assert(num_splits == 0 && + "`MhaVarlenFwd`: split-KV is not supported by this Ascend path"); + + auto stream = static_cast(stream_); + int64_t sparse_mode = 0; + int64_t pre_tokens = 2147483647; + int64_t next_tokens = 2147483647; + ascend::fia::ResolveSparseMode(is_causal_, window_size_left_, + window_size_right_, sparse_mode, pre_tokens, + next_tokens); + + auto seq_q = ascend::fia::CreateCumSeqLengths(cu_seqlens_q, stream); + auto seq_k = has_block_table_ + ? ascend::fia::CreateDiffSeqLengths(cu_seqlens_k, stream) + : ascend::fia::CreateCumSeqLengths(cu_seqlens_k, stream); + + auto t_q = q_cache_.get(const_cast(q.data())); + auto t_out = out_cache_.get(out.data()); + aclTensor* t_block_table = nullptr; + aclTensor* t_k = nullptr; + aclTensor* t_v = nullptr; + + if (has_block_table_) { + assert(page_block_size_ == static_cast(k.size(1)) && + "`MhaVarlenFwd`: paged KV `block_size` changed after descriptor " + "creation"); + t_block_table = + block_table_cache_.get(const_cast(block_table->data())); + t_k = aclCreateTensor(kv_shape_.data(), + static_cast(kv_shape_.size()), + kv_acl_dtype_, kv_strides_.data(), 0, ACL_FORMAT_ND, + kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(k.data())); + t_v = aclCreateTensor(kv_shape_.data(), + static_cast(kv_shape_.size()), + kv_acl_dtype_, kv_strides_.data(), 0, ACL_FORMAT_ND, + kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(v.data())); + } else { + t_k = ascend::BuildAclTensor(k); + t_v = ascend::BuildAclTensor(v); + } + + const aclTensor* key_arr[] = {t_k}; + const aclTensor* value_arr[] = {t_v}; + auto key_list = aclCreateTensorList(key_arr, 1); + auto value_list = aclCreateTensorList(value_arr, 1); + + uint64_t ws_size = 0; + aclOpExecutor* executor = nullptr; + auto ret = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_q, key_list, value_list, + nullptr, // pseShift. + causal_mask_, // attenMask. + seq_q, seq_k, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, t_block_table, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, static_cast(num_heads_), softmax_scale_, pre_tokens, + next_tokens, const_cast("TND"), + static_cast(num_kv_heads_), sparse_mode, + 0, // innerPrecise. + page_block_size_, // blockSize. + 0, // antiquantMode. + false, // softmaxLseFlag. + 0, 0, 0, t_out, nullptr, &ws_size, &executor); + assert(ret == ACL_SUCCESS && + "`aclnnFusedInferAttentionScoreV4GetWorkspaceSize` failed"); + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); + ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_size, executor, stream); + assert(ret == ACL_SUCCESS && "`aclnnFusedInferAttentionScoreV4` failed"); + + ascend::DeferOrRunAclCleanup([key_list, value_list, seq_q, seq_k]() { + aclDestroyTensorList(key_list); + aclDestroyTensorList(value_list); + aclDestroyIntArray(seq_q); + aclDestroyIntArray(seq_k); + }); + } + + private: + mutable ascend::AclTensorCache q_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache block_table_cache_; + + aclTensor* causal_mask_ = nullptr; + + void* causal_mask_buf_ = nullptr; + + std::vector kv_shape_; + + std::vector kv_strides_; + + std::vector kv_storage_shape_; + + aclDataType kv_acl_dtype_{ACL_DT_UNDEFINED}; + + int64_t page_block_size_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_mha_varlen_fwd.py b/tests/test_mha_varlen_fwd.py new file mode 100644 index 000000000..35dd14de6 --- /dev/null +++ b/tests/test_mha_varlen_fwd.py @@ -0,0 +1,306 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "seq_lens, num_heads, num_kv_heads, head_size", + ( + ((16,), 8, 8, 64), + ((7, 11, 5), 8, 2, 64), + ((9, 3), 16, 4, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_mha_varlen_fwd( + seq_lens, + num_heads, + num_kv_heads, + head_size, + implementation_index, + dtype, + device, + rtol, + atol, +): + scale = 1.0 / head_size**0.5 + total_tokens = sum(seq_lens) + q = randn_strided( + (total_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + k = randn_strided( + (total_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + v = randn_strided( + (total_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + out = torch.empty_like(q) + cu = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], + dtype=torch.int32, + device=device, + ) + + return Payload( + lambda *args, **kwargs: _mha_varlen_fwd( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_mha_varlen_fwd, + (q, k, v, out, cu, cu), + { + "seq_lens": seq_lens, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "head_size": head_size, + "scale": scale, + "causal": True, + }, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "seq_lens, num_heads, num_kv_heads, head_size", + ( + ((5, 3), 8, 2, 64), + ((7,), 16, 4, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_mha_varlen_fwd_paged_kv( + seq_lens, + num_heads, + num_kv_heads, + head_size, + implementation_index, + dtype, + device, + rtol, + atol, +): + block_size = 256 + batch_size = len(seq_lens) + blocks_per_seq = (max(seq_lens) + block_size - 1) // block_size + num_blocks = batch_size * blocks_per_seq + scale = 1.0 / head_size**0.5 + total_tokens = sum(seq_lens) + + q = randn_strided( + (total_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + kcache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + vcache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + out = torch.empty_like(q) + cu = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(batch_size)], + dtype=torch.int32, + device=device, + ) + block_table = torch.empty( + (batch_size, blocks_per_seq), dtype=torch.int32, device=device + ) + + for batch in range(batch_size): + for block in range(blocks_per_seq): + block_table[batch, block] = batch * blocks_per_seq + block + + return Payload( + lambda *args, **kwargs: _mha_varlen_fwd( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_mha_varlen_fwd_paged_kv, + (q, kcache, vcache, out, cu, cu, block_table), + { + "seq_lens": seq_lens, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "head_size": head_size, + "block_size": block_size, + "scale": scale, + "causal": True, + }, + rtol=rtol, + atol=atol, + ) + + +def _mha_varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_k, + block_table=None, + *, + seq_lens, + num_heads, + num_kv_heads, + head_size, + block_size=None, + scale, + causal, + implementation_index=0, +): + del block_size + + infini.ops.mha_varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_k, + None, + None, + block_table, + None, + max(seq_lens), + max(seq_lens), + 0.0, + scale, + False, + causal, + -1, + 0, + 0.0, + False, + None, + 0, + implementation_index=implementation_index, + stream=get_stream(q.device), + ) + + return out + + +def _ref_mha_varlen_fwd_paged_kv( + q, + kcache, + vcache, + out, + cu_seqlens_q, + cu_seqlens_k, + block_table, + *, + seq_lens, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal, +): + del out, cu_seqlens_q, cu_seqlens_k, head_size + + q_cpu = q.cpu().float() + kcache_cpu = kcache.cpu().float() + vcache_cpu = vcache.cpu().float() + block_table_cpu = block_table.cpu() + outputs = [] + q_offset = 0 + + for batch, seq_len in enumerate(seq_lens): + remaining = seq_len + k_pages = [] + v_pages = [] + + for block in block_table_cpu[batch]: + if remaining <= 0: + break + + take = min(remaining, block_size) + block_id = int(block.item()) + k_pages.append(kcache_cpu[block_id, :take].transpose(0, 1)) + v_pages.append(vcache_cpu[block_id, :take].transpose(0, 1)) + remaining -= take + + k_i = torch.cat(k_pages, dim=1) + v_i = torch.cat(v_pages, dim=1) + + if num_kv_heads < num_heads: + repeat = num_heads // num_kv_heads + k_i = k_i.repeat_interleave(repeat, dim=0) + v_i = v_i.repeat_interleave(repeat, dim=0) + + ref = torch.nn.functional.scaled_dot_product_attention( + q_cpu[q_offset : q_offset + seq_len].transpose(0, 1).unsqueeze(0), + k_i.unsqueeze(0), + v_i.unsqueeze(0), + scale=scale, + is_causal=causal, + ) + outputs.append(ref.squeeze(0).transpose(0, 1).to(q.dtype)) + q_offset += seq_len + + return torch.cat(outputs, dim=0).to(q.device) + + +def _ref_mha_varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_k, + *, + seq_lens, + num_heads, + num_kv_heads, + head_size, + scale, + causal, +): + del out, cu_seqlens_q, cu_seqlens_k, head_size + + outputs = [] + offset = 0 + + for seq_len in seq_lens: + q_i = q[offset : offset + seq_len].cpu().float().transpose(0, 1) + k_i = k[offset : offset + seq_len].cpu().float().transpose(0, 1) + v_i = v[offset : offset + seq_len].cpu().float().transpose(0, 1) + + if num_kv_heads < num_heads: + repeat = num_heads // num_kv_heads + k_i = k_i.repeat_interleave(repeat, dim=0) + v_i = v_i.repeat_interleave(repeat, dim=0) + + ref = torch.nn.functional.scaled_dot_product_attention( + q_i.unsqueeze(0), + k_i.unsqueeze(0), + v_i.unsqueeze(0), + scale=scale, + is_causal=causal, + ) + outputs.append(ref.squeeze(0).transpose(0, 1).to(q.dtype)) + offset += seq_len + + return torch.cat(outputs, dim=0).to(q.device)