Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions src/base/paged_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
#ifndef INFINI_OPS_BASE_PAGED_ATTENTION_H_
#define INFINI_OPS_BASE_PAGED_ATTENTION_H_

#include <cstddef>
#include <cstdint>
#include <optional>

#include "data_type.h"
#include "operator.h"

namespace infini::ops {

// Paged decode attention operator.
//
// Performs multi-head attention over paged KV caches for decode (single-token
// queries per sequence).
//
// Interface follows vLLM's paged attention convention:
// - vLLM CUDA: `torch.ops.vllm.paged_attention_v1` uses the same query
// shape [batch, num_heads, head_size] and seq_lens [batch] int32.
// KV cache differs (5D on CUDA for vectorization, 4D here).
// - vLLM-Ascend: `torch_npu._npu_paged_attention` wraps ATB
// `PagedAttentionParam` with default `inputLayout` (`TYPE_BSND`).
// - ATB `PagedAttentionParam`: `headNum`, `kvHeadNum`, `qkScale`,
// `maskType` (default NORM), `inputLayout` (default `TYPE_BSND`).
//
// Input layout (BSND with S=1 for decode):
// query : [batch, num_heads, head_size]
// key_cache : [num_blocks, block_size, num_kv_heads, head_size]
// value_cache : [num_blocks, block_size, num_kv_heads, head_size]
// seq_lens : [batch] int32 — total context length per sequence
// block_table : [batch, max_num_blocks_per_seq] int32
//
// Output layout:
// output : [batch, num_heads, head_size]
//
// Optional host tensors: `seq_lens_host` and `block_table_host` are CPU
// mirrors of `seq_lens` and `block_table`. They exist because CANN's
// paged-attention APIs mandate CPU-resident metadata — aclnn declares
// `qSeqLens` as a CPU tensor in its signature, and ATB
// `PagedAttentionParam` reads `aclIntArray*` parameters from the
// `hostData` field at `aclnnRunner::Setup()` time. Without caller-
// provided host tensors, the kernel must synchronously D2H-copy both
// each call, which (a) blocks the stream and (b) prevents NPUGraph
// capture (sync copies are not capturable). When the caller already
// has CPU-pinned copies (e.g. vLLM's `optimistic_seq_lens_cpu` and
// `BlockTable.get_cpu_tensor()`), passing them through lets the kernel
// skip both D2H copies and be captured into a full NPUGraph.
class PagedAttention : public Operator<PagedAttention> {
public:
// Paged attention follows vLLM naming. `output` is explicit because
// InfiniOps operators are in-place; it remains before optional host mirrors
// to preserve the existing call surface.
PagedAttention(const Tensor query, const Tensor key_cache,
const Tensor value_cache, const Tensor seq_lens,
const Tensor block_table, int64_t num_heads,
int64_t num_kv_heads, int64_t head_size, double scale,
int64_t block_size, Tensor output,
std::optional<Tensor> seq_lens_host = std::nullopt,
std::optional<Tensor> block_table_host = std::nullopt)
: batch_size_{query.size(0)},
num_heads_{num_heads},
num_kv_heads_{num_kv_heads},
head_size_{head_size},
scale_{scale},
block_size_{block_size},
dtype_{query.dtype()},
query_shape_{query.shape()},
key_cache_shape_{key_cache.shape()},
value_cache_shape_{value_cache.shape()},
seq_lens_shape_{seq_lens.shape()},
block_table_shape_{block_table.shape()},
output_shape_{output.shape()},
has_seq_lens_host_{seq_lens_host.has_value()},
has_block_table_host_{block_table_host.has_value()} {
assert(
num_heads % num_kv_heads == 0 &&
"`PagedAttention` requires `num_heads` divisible by `num_kv_heads`.");
assert(query.ndim() == 3 &&
"`PagedAttention` requires query to be 3D [batch, num_heads, "
"head_size].");
assert(key_cache.ndim() == 4 &&
"`PagedAttention` requires key_cache to be 4D [num_blocks, "
"block_size, num_kv_heads, head_size].");
assert(value_cache.ndim() == 4 &&
"`PagedAttention` requires value_cache to be 4D [num_blocks, "
"block_size, num_kv_heads, head_size].");
assert(key_cache.shape() == value_cache.shape() &&
"`PagedAttention` requires key_cache and value_cache same shape.");
assert(query.dtype() == key_cache.dtype() &&
query.dtype() == value_cache.dtype() &&
query.dtype() == output.dtype() &&
"`PagedAttention` requires query, caches, and output same dtype.");
assert(query.size(1) == static_cast<Tensor::Size>(num_heads) &&
"`PagedAttention` requires num_heads to match query shape.");
assert(key_cache.size(1) == static_cast<Tensor::Size>(block_size) &&
"`PagedAttention` requires block_size to match cache shape.");
assert(key_cache.size(2) == static_cast<Tensor::Size>(num_kv_heads) &&
"`PagedAttention` requires num_kv_heads to match cache shape.");
assert(query.size(2) == static_cast<Tensor::Size>(head_size) &&
key_cache.size(3) == static_cast<Tensor::Size>(head_size) &&
"`PagedAttention` requires head_size to match query and cache.");
assert(query.stride(-1) == 1 && key_cache.stride(-1) == 1 &&
value_cache.stride(-1) == 1 && output.stride(-1) == 1 &&
"`PagedAttention` requires contiguous last dimension.");
assert(output.shape() == query.shape() &&
"`PagedAttention` requires output to match query shape.");
assert(seq_lens.ndim() == 1 &&
"`PagedAttention` requires seq_lens to be 1D [batch].");
assert(seq_lens.size(0) == batch_size_ &&
"`PagedAttention` requires seq_lens batch to match query.");
assert(seq_lens.dtype() == DataType::kInt32 &&
"`PagedAttention` requires seq_lens to be int32.");
assert(block_table.ndim() == 2 &&
"`PagedAttention` requires block_table to be 2D [batch, "
"max_num_blocks].");
assert(block_table.size(0) == batch_size_ &&
"`PagedAttention` requires block_table batch to match query.");
assert(block_table.dtype() == DataType::kInt32 &&
"`PagedAttention` requires block_table to be int32.");

if (seq_lens_host.has_value()) {
assert(seq_lens_host->shape() == seq_lens.shape() &&
"`PagedAttention` requires seq_lens_host to mirror seq_lens.");
assert(seq_lens_host->dtype() == seq_lens.dtype() &&
"`PagedAttention` requires seq_lens_host dtype to match "
"seq_lens.");
assert(seq_lens_host->device().type() == Device::Type::kCpu &&
"`PagedAttention` requires seq_lens_host to be on CPU.");
}

if (block_table_host.has_value()) {
assert(block_table_host->shape() == block_table.shape() &&
"`PagedAttention` requires block_table_host to mirror "
"block_table.");
assert(block_table_host->dtype() == block_table.dtype() &&
"`PagedAttention` requires block_table_host dtype to match "
"block_table.");
assert(block_table_host->device().type() == Device::Type::kCpu &&
"`PagedAttention` requires block_table_host to be on CPU.");
}
}

virtual void operator()(
const Tensor query, const Tensor key_cache, const Tensor value_cache,
const Tensor seq_lens, const Tensor block_table, int64_t num_heads,
int64_t num_kv_heads, int64_t head_size, double scale, int64_t block_size,
Tensor output, std::optional<Tensor> seq_lens_host = std::nullopt,
std::optional<Tensor> block_table_host = std::nullopt) const = 0;

protected:
Tensor::Size batch_size_{0};

int64_t num_heads_{0};

int64_t num_kv_heads_{0};

int64_t head_size_{0};

double scale_{0.0};

int64_t block_size_{0};

const DataType dtype_;

Tensor::Shape query_shape_;

Tensor::Shape key_cache_shape_;

Tensor::Shape value_cache_shape_;

Tensor::Shape seq_lens_shape_;

Tensor::Shape block_table_shape_;

Tensor::Shape output_shape_;

bool has_seq_lens_host_{false};

bool has_block_table_host_{false};
};

} // namespace infini::ops

#endif // INFINI_OPS_BASE_PAGED_ATTENTION_H_
Loading
Loading