From 73867ca7707430c2f36acdfcda8e94d293bdc908 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 30 Jun 2026 15:09:32 +0800 Subject: [PATCH] feat(ascend): add `paged_attention` operator --- src/base/paged_attention.h | 185 ++++++ .../ascend/ops/paged_attention/kernel_atb.h | 283 +++++++++ tests/test_paged_attention.py | 554 ++++++++++++++++++ 3 files changed, 1022 insertions(+) create mode 100644 src/base/paged_attention.h create mode 100644 src/native/ascend/ops/paged_attention/kernel_atb.h create mode 100644 tests/test_paged_attention.py diff --git a/src/base/paged_attention.h b/src/base/paged_attention.h new file mode 100644 index 000000000..fa9f9a5db --- /dev/null +++ b/src/base/paged_attention.h @@ -0,0 +1,185 @@ +#ifndef INFINI_OPS_BASE_PAGED_ATTENTION_H_ +#define INFINI_OPS_BASE_PAGED_ATTENTION_H_ + +#include +#include +#include + +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Paged decode attention operator. +// +// Performs multi-head attention over paged KV caches for decode (single-token +// queries per sequence). +// +// Interface follows vLLM's paged attention convention: +// - vLLM CUDA: `torch.ops.vllm.paged_attention_v1` uses the same query +// shape [batch, num_heads, head_size] and seq_lens [batch] int32. +// KV cache differs (5D on CUDA for vectorization, 4D here). +// - vLLM-Ascend: `torch_npu._npu_paged_attention` wraps ATB +// `PagedAttentionParam` with default `inputLayout` (`TYPE_BSND`). +// - ATB `PagedAttentionParam`: `headNum`, `kvHeadNum`, `qkScale`, +// `maskType` (default NORM), `inputLayout` (default `TYPE_BSND`). +// +// Input layout (BSND with S=1 for decode): +// query : [batch, num_heads, head_size] +// key_cache : [num_blocks, block_size, num_kv_heads, head_size] +// value_cache : [num_blocks, block_size, num_kv_heads, head_size] +// seq_lens : [batch] int32 — total context length per sequence +// block_table : [batch, max_num_blocks_per_seq] int32 +// +// Output layout: +// output : [batch, num_heads, head_size] +// +// Optional host tensors: `seq_lens_host` and `block_table_host` are CPU +// mirrors of `seq_lens` and `block_table`. They exist because CANN's +// paged-attention APIs mandate CPU-resident metadata — aclnn declares +// `qSeqLens` as a CPU tensor in its signature, and ATB +// `PagedAttentionParam` reads `aclIntArray*` parameters from the +// `hostData` field at `aclnnRunner::Setup()` time. Without caller- +// provided host tensors, the kernel must synchronously D2H-copy both +// each call, which (a) blocks the stream and (b) prevents NPUGraph +// capture (sync copies are not capturable). When the caller already +// has CPU-pinned copies (e.g. vLLM's `optimistic_seq_lens_cpu` and +// `BlockTable.get_cpu_tensor()`), passing them through lets the kernel +// skip both D2H copies and be captured into a full NPUGraph. +class PagedAttention : public Operator { + public: + // Paged attention follows vLLM naming. `output` is explicit because + // InfiniOps operators are in-place; it remains before optional host mirrors + // to preserve the existing call surface. + PagedAttention(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output, + std::optional seq_lens_host = std::nullopt, + std::optional block_table_host = std::nullopt) + : batch_size_{query.size(0)}, + num_heads_{num_heads}, + num_kv_heads_{num_kv_heads}, + head_size_{head_size}, + scale_{scale}, + block_size_{block_size}, + dtype_{query.dtype()}, + query_shape_{query.shape()}, + key_cache_shape_{key_cache.shape()}, + value_cache_shape_{value_cache.shape()}, + seq_lens_shape_{seq_lens.shape()}, + block_table_shape_{block_table.shape()}, + output_shape_{output.shape()}, + has_seq_lens_host_{seq_lens_host.has_value()}, + has_block_table_host_{block_table_host.has_value()} { + assert( + num_heads % num_kv_heads == 0 && + "`PagedAttention` requires `num_heads` divisible by `num_kv_heads`."); + assert(query.ndim() == 3 && + "`PagedAttention` requires query to be 3D [batch, num_heads, " + "head_size]."); + assert(key_cache.ndim() == 4 && + "`PagedAttention` requires key_cache to be 4D [num_blocks, " + "block_size, num_kv_heads, head_size]."); + assert(value_cache.ndim() == 4 && + "`PagedAttention` requires value_cache to be 4D [num_blocks, " + "block_size, num_kv_heads, head_size]."); + assert(key_cache.shape() == value_cache.shape() && + "`PagedAttention` requires key_cache and value_cache same shape."); + assert(query.dtype() == key_cache.dtype() && + query.dtype() == value_cache.dtype() && + query.dtype() == output.dtype() && + "`PagedAttention` requires query, caches, and output same dtype."); + assert(query.size(1) == static_cast(num_heads) && + "`PagedAttention` requires num_heads to match query shape."); + assert(key_cache.size(1) == static_cast(block_size) && + "`PagedAttention` requires block_size to match cache shape."); + assert(key_cache.size(2) == static_cast(num_kv_heads) && + "`PagedAttention` requires num_kv_heads to match cache shape."); + assert(query.size(2) == static_cast(head_size) && + key_cache.size(3) == static_cast(head_size) && + "`PagedAttention` requires head_size to match query and cache."); + assert(query.stride(-1) == 1 && key_cache.stride(-1) == 1 && + value_cache.stride(-1) == 1 && output.stride(-1) == 1 && + "`PagedAttention` requires contiguous last dimension."); + assert(output.shape() == query.shape() && + "`PagedAttention` requires output to match query shape."); + assert(seq_lens.ndim() == 1 && + "`PagedAttention` requires seq_lens to be 1D [batch]."); + assert(seq_lens.size(0) == batch_size_ && + "`PagedAttention` requires seq_lens batch to match query."); + assert(seq_lens.dtype() == DataType::kInt32 && + "`PagedAttention` requires seq_lens to be int32."); + assert(block_table.ndim() == 2 && + "`PagedAttention` requires block_table to be 2D [batch, " + "max_num_blocks]."); + assert(block_table.size(0) == batch_size_ && + "`PagedAttention` requires block_table batch to match query."); + assert(block_table.dtype() == DataType::kInt32 && + "`PagedAttention` requires block_table to be int32."); + + if (seq_lens_host.has_value()) { + assert(seq_lens_host->shape() == seq_lens.shape() && + "`PagedAttention` requires seq_lens_host to mirror seq_lens."); + assert(seq_lens_host->dtype() == seq_lens.dtype() && + "`PagedAttention` requires seq_lens_host dtype to match " + "seq_lens."); + assert(seq_lens_host->device().type() == Device::Type::kCpu && + "`PagedAttention` requires seq_lens_host to be on CPU."); + } + + if (block_table_host.has_value()) { + assert(block_table_host->shape() == block_table.shape() && + "`PagedAttention` requires block_table_host to mirror " + "block_table."); + assert(block_table_host->dtype() == block_table.dtype() && + "`PagedAttention` requires block_table_host dtype to match " + "block_table."); + assert(block_table_host->device().type() == Device::Type::kCpu && + "`PagedAttention` requires block_table_host to be on CPU."); + } + } + + virtual void operator()( + const Tensor query, const Tensor key_cache, const Tensor value_cache, + const Tensor seq_lens, const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, int64_t block_size, + Tensor output, std::optional seq_lens_host = std::nullopt, + std::optional block_table_host = std::nullopt) const = 0; + + protected: + Tensor::Size batch_size_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + double scale_{0.0}; + + int64_t block_size_{0}; + + const DataType dtype_; + + Tensor::Shape query_shape_; + + Tensor::Shape key_cache_shape_; + + Tensor::Shape value_cache_shape_; + + Tensor::Shape seq_lens_shape_; + + Tensor::Shape block_table_shape_; + + Tensor::Shape output_shape_; + + bool has_seq_lens_host_{false}; + + bool has_block_table_host_{false}; +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_BASE_PAGED_ATTENTION_H_ diff --git a/src/native/ascend/ops/paged_attention/kernel_atb.h b/src/native/ascend/ops/paged_attention/kernel_atb.h new file mode 100644 index 000000000..dc59e1f3e --- /dev/null +++ b/src/native/ascend/ops/paged_attention/kernel_atb.h @@ -0,0 +1,283 @@ +#ifndef INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/paged_attention.h" +#include "native/ascend/atb_common_.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based paged decode attention (implementation index 0). +// +// Wraps ATB `PagedAttentionParam` with the default `inputLayout` +// (`TYPE_BSND`). For decode (single token per request) the S +// dimension is implicitly 1, so query and output use 3D shape +// [batch, num_heads, head_size] matching vLLM's convention. +// +// ATB internally constructs `aclIntArray*` from the `hostData` field +// of `block_table` and `context_lens` tensors. By default the operator +// performs synchronous D2H copies for these two small tensors each call. +// When the caller provides `seq_lens_host` and `block_table_host` (CPU +// pinned tensors), the D2H copies are skipped entirely — enabling full +// NPUGraph capture of the decode attention path. +// +// ATB VariantPack layout (BSND with S=1): +// inTensors[0] = query [B, N, D] +// inTensors[1] = key_cache [num_blocks, block_size, Nkv, D] +// inTensors[2] = value_cache [num_blocks, block_size, Nkv, D] +// inTensors[3] = block_table [B, max_num_blocks] (device + host) +// inTensors[4] = context_lens [B] (int32) (device + host) +// outTensors[0] = output [B, N, D] +template <> +class Operator + : public PagedAttention { + public: + Operator(const Tensor query, const Tensor key_cache, const Tensor value_cache, + const Tensor seq_lens, const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output, + std::optional seq_lens_host = std::nullopt, + std::optional block_table_host = std::nullopt) + : PagedAttention(query, key_cache, value_cache, seq_lens, block_table, + num_heads, num_kv_heads, head_size, scale, block_size, + output, seq_lens_host, block_table_host) { + int64_t B = static_cast(batch_size_); + int64_t N = num_heads_; + int64_t Nkv = num_kv_heads_; + int64_t D = head_size_; + + // Query/output shapes: 3D [B, N, D] (BSND with S=1 for decode). + query_tnd_shape_ = {B, N, D}; + output_tnd_shape_ = {B, N, D}; + + // KV cache shapes. + int64_t num_blocks = static_cast(key_cache.size(0)); + int64_t bs = static_cast(key_cache.size(1)); + kv_cache_shape_ = {num_blocks, bs, Nkv, D}; + + // Block table and context lens shapes. + int64_t max_blocks = static_cast(block_table.size(1)); + block_table_shape_ = {B, max_blocks}; + context_lens_shape_ = {B}; + + // ACL data types. + acl_dt_ = ascend::ToAclDtype(query.dtype()); + bt_dt_ = ascend::ToAclDtype(block_table.dtype()); + sl_dt_ = ascend::ToAclDtype(seq_lens.dtype()); + + // Element sizes for `dataSize` computation. + elem_size_ = query.element_size(); + bt_elem_size_ = block_table.element_size(); + sl_elem_size_ = seq_lens.element_size(); + + // Pre-allocate pinned host buffers for D2H copies. + // ATB PA reads `hostData` from block_table and context_lens to + // construct internal `aclIntArray*` parameters. + // When caller provides host tensors, skip allocation — the caller's + // pinned buffers will be used directly in `operator()`. + bt_host_bytes_ = static_cast(B * max_blocks) * bt_elem_size_; + sl_host_bytes_ = static_cast(B) * sl_elem_size_; + + if (!has_block_table_host_) { + bt_host_ = std::malloc(bt_host_bytes_); + assert(bt_host_ && "Host buffer allocation for `block_table` failed"); + } + + if (!has_seq_lens_host_) { + sl_host_ = std::malloc(sl_host_bytes_); + assert(sl_host_ && "Host buffer allocation for `seq_lens` failed"); + } + + // Create the ATB operation (reused across calls). + atb::infer::PagedAttentionParam param; + param.headNum = static_cast(N); + param.kvHeadNum = static_cast(Nkv); + param.qkScale = static_cast(scale_); + + atb::Status s = atb::CreateOperation(param, &op_); + assert(s == atb::NO_ERROR && "atb::CreateOperation(PagedAttention) failed"); + } + + ~Operator() { + // Host memory is always safe to free. + if (!has_block_table_host_) { + std::free(bt_host_); + } + + if (!has_seq_lens_host_) { + std::free(sl_host_); + } + + if (!ascend::IsAclRuntimeAlive()) return; + + if (op_) { + atb::DestroyOperation(op_); + } + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output, + std::optional seq_lens_host, + std::optional block_table_host) const override { + auto stream = static_cast(stream_); + atb::Context* ctx = ascend::GetAtbContext(stream); + + // Use caller-provided host data or perform synchronous D2H copy. + // ATB reads `hostData` to construct internal `aclIntArray*`. + void* bt_host_ptr = bt_host_; + void* sl_host_ptr = sl_host_; + + if (block_table_host.has_value()) { + bt_host_ptr = const_cast(block_table_host.value().data()); + } else { + aclrtMemcpy(bt_host_, bt_host_bytes_, block_table.data(), bt_host_bytes_, + ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (seq_lens_host.has_value()) { + sl_host_ptr = const_cast(seq_lens_host.value().data()); + } else { + aclrtMemcpy(sl_host_, sl_host_bytes_, seq_lens.data(), sl_host_bytes_, + ACL_MEMCPY_DEVICE_TO_HOST); + } + + atb::VariantPack vp = buildVariantPack( + const_cast(query.data()), const_cast(key_cache.data()), + const_cast(value_cache.data()), + const_cast(block_table.data()), + const_cast(seq_lens.data()), output.data(), bt_host_ptr, + sl_host_ptr); + + // Setup computes workspace requirements and binds tensor descriptors. + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Setup(PagedAttention) failed"); + + // Allocate workspace via the shared pool. + uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Execute(PagedAttention) failed"); + } + + private: + // Build the ATB VariantPack. + // + // Query and output are 3D [B, N, D] (BSND with S=1 for decode). + // Block table and context lens carry both `deviceData` and + // `hostData` because ATB reads the host copy to build internal + // `aclIntArray*` parameters. + atb::VariantPack buildVariantPack(void* query_data, void* key_cache_data, + void* value_cache_data, + void* block_table_data, void* seq_lens_data, + void* output_data, void* bt_host_ptr, + void* sl_host_ptr) const { + int64_t B = query_tnd_shape_[0]; + int64_t N = query_tnd_shape_[1]; + int64_t D = query_tnd_shape_[2]; + + // Query [B, N, D] — 3D (BSND with S=1). + uint64_t q_bytes = static_cast(B * N * D) * elem_size_; + atb::Tensor t_query = + ascend::ToAtbTensor(query_tnd_shape_, acl_dt_, query_data, q_bytes); + + // KV caches [num_blocks, block_size, Nkv, D]. + int64_t nb = kv_cache_shape_[0]; + int64_t bs = kv_cache_shape_[1]; + int64_t Nkv = kv_cache_shape_[2]; + uint64_t kv_bytes = static_cast(nb * bs * Nkv * D) * elem_size_; + atb::Tensor t_key_cache = + ascend::ToAtbTensor(kv_cache_shape_, acl_dt_, key_cache_data, kv_bytes); + atb::Tensor t_value_cache = ascend::ToAtbTensor(kv_cache_shape_, acl_dt_, + value_cache_data, kv_bytes); + + // Block table [B, max_blocks] — with hostData for `aclIntArray*`. + atb::Tensor t_block_table = ascend::ToAtbTensor( + block_table_shape_, bt_dt_, block_table_data, bt_host_bytes_); + t_block_table.hostData = bt_host_ptr; + + // Context lens [B] — with hostData for `aclIntArray*`. + atb::Tensor t_context_lens = ascend::ToAtbTensor( + context_lens_shape_, sl_dt_, seq_lens_data, sl_host_bytes_); + t_context_lens.hostData = sl_host_ptr; + + // Output [B, N, D] — 3D (BSND with S=1). + atb::Tensor t_output = + ascend::ToAtbTensor(output_tnd_shape_, acl_dt_, output_data, q_bytes); + + atb::VariantPack vp; + vp.inTensors = {t_query, t_key_cache, t_value_cache, t_block_table, + t_context_lens}; + vp.outTensors = {t_output}; + + return vp; + } + + atb::Operation* op_ = nullptr; + + std::vector query_tnd_shape_; + + std::vector output_tnd_shape_; + + std::vector kv_cache_shape_; + + std::vector block_table_shape_; + + std::vector context_lens_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + aclDataType bt_dt_ = ACL_DT_UNDEFINED; + + aclDataType sl_dt_ = ACL_DT_UNDEFINED; + + uint64_t elem_size_ = 0; + + uint64_t bt_elem_size_ = 0; + + uint64_t sl_elem_size_ = 0; + + // Host-side buffers for ATB's internal `aclIntArray*` construction. + void* bt_host_ = nullptr; + + void* sl_host_ = nullptr; + + uint64_t bt_host_bytes_ = 0; + + uint64_t sl_host_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py new file mode 100644 index 000000000..c2258ffaf --- /dev/null +++ b/tests/test_paged_attention.py @@ -0,0 +1,554 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + + +def _atb_pa_unsupported_reason(): + """Return a reason string if ATB PagedAttention can't run here, else `""`. + + Uses a narrow SoC-name check rather than a try/except on the op under + test — the latter silently masks real regressions by converting any + runtime failure in `paged_attention` into a clean skip. + """ + if not (hasattr(torch, "npu") and torch.npu.is_available()): + return "NPU not available" + + if not infini.ops.PagedAttention.active_implementation_indices("ascend"): + return "ATB PagedAttention implementation not registered for Ascend" + + return "" + + +_skip_no_atb_pa = pytest.mark.skipif( + bool(_atb_pa_unsupported_reason()), + reason=_atb_pa_unsupported_reason() or "ATB PagedAttention unsupported", +) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + (32, 32, 128, 128), # MHA + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_basic( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Basic paged decode attention with contiguous block assignments.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 4 + kv_len = 16 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + # Context lengths (total KV length per request). + seq_lens = torch.full((num_reqs,), kv_len, dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_variable_seq_lens( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Paged decode attention where each request has a different KV length.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + kv_lens = [8, 32, 16, 128] + num_reqs = len(kv_lens) + max_blocks_per_req = max((kv + block_size - 1) // block_size for kv in kv_lens) + num_blocks = sum((kv + block_size - 1) // block_size for kv in kv_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: assign blocks sequentially. + block_table = torch.zeros( + (num_reqs, max_blocks_per_req), dtype=torch.int32, device=device + ) + block_idx = 0 + + for i in range(num_reqs): + n_blocks = (kv_lens[i] + block_size - 1) // block_size + + for j in range(n_blocks): + block_table[i, j] = block_idx + block_idx += 1 + + seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ((torch.float16, 1e-3, 1e-3),), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_single_request( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Single request decode (batch_size=1).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 1 + kv_len = 64 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + block_table = torch.arange( + num_blocks_per_req, dtype=torch.int32, device=device + ).unsqueeze(0) + + seq_lens = torch.tensor([kv_len], dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ((torch.float16, 1e-3, 1e-3),), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_host_tensors( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Paged decode with caller-provided host tensors (D2H-free path).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 4 + kv_len = 16 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + seq_lens = torch.full((num_reqs,), kv_len, dtype=torch.int32, device=device) + + # CPU copies for the D2H-free path. + seq_lens_cpu = seq_lens.cpu().contiguous() + block_table_cpu = block_table.cpu().contiguous() + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention_with_host( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + seq_lens_cpu, + block_table_cpu, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _paged_attention_with_host( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + seq_lens_host, + block_table_host, +): + """Call paged attention with caller-provided host tensors.""" + infini.ops.paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + seq_lens_host=seq_lens_host, + block_table_host=block_table_host, + stream=get_stream(query.device), + ) + + return output + + +def _paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, +): + infini.ops.paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + stream=get_stream(query.device), + ) + + return output + + +def _ref_paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, +): + """PyTorch SDPA reference for paged decode attention.""" + sl = seq_lens.cpu() + bt = block_table.cpu() + kc = key_cache.cpu().float() + vc = value_cache.cpu().float() + q_cpu = query.cpu().float() + num_reqs = bt.size(0) + outputs = [] + + for i in range(num_reqs): + q = q_cpu[i : i + 1] # [1, N, D] + kv_len = int(sl[i].item()) + + # Gather K and V from paged cache. + # Cache layout: [num_blocks, block_size, Nkv, D]. + blocks = bt[i] + k_pages = [] + v_pages = [] + remaining = kv_len + + for b in blocks: + if remaining <= 0: + break + + take = min(remaining, block_size) + k_pages.append(kc[int(b.item()), :take, :, :]) + v_pages.append(vc[int(b.item()), :take, :, :]) + remaining -= take + + # [kv_len, Nkv, D] + k = torch.cat(k_pages, dim=0) + v = torch.cat(v_pages, dim=0) + + # SDPA reference with GQA expansion. + # q: [1, N, D] -> [N, 1, D] + q_t = q.transpose(0, 1) + # k, v: [kv_len, Nkv, D] -> [Nkv, kv_len, D] + k_t = k.transpose(0, 1) + v_t = v.transpose(0, 1) + + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(ratio, dim=0) + v_t = v_t.repeat_interleave(ratio, dim=0) + + # [N, 1, D] and [N, kv_len, D] -> [1, N, 1, D] and [1, N, kv_len, D] + q_4d = q_t.unsqueeze(0) + k_4d = k_t.unsqueeze(0) + v_4d = v_t.unsqueeze(0) + + # Decode: query attends to all past KV (no causal mask). + out = torch.nn.functional.scaled_dot_product_attention( + q_4d, + k_4d, + v_4d, + scale=scale, + is_causal=False, + ) + + # [1, N, 1, D] -> [1, N, D] + outputs.append(out.squeeze(0).transpose(0, 1).squeeze(0).unsqueeze(0)) + + return torch.cat(outputs, dim=0).to(query.dtype).to(query.device)