diff --git a/src/native/ascend/ops/reshape_and_cache/kernel.h b/src/native/ascend/ops/reshape_and_cache/kernel.h new file mode 100644 index 000000000..b6a016c74 --- /dev/null +++ b/src/native/ascend/ops/reshape_and_cache/kernel.h @@ -0,0 +1,109 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_index_copy.h" +#include "base/reshape_and_cache.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Device-side scatter via aclnnInplaceIndexCopy. +// +// The previous implementation copied slot_mapping D2H (aclrtSynchronizeStream), +// then issued per-token D2D memcpy in a host loop. For batch=256, this meant +// ~100 us sync + ~500 us host loop overhead. aclnnInplaceIndexCopy performs +// the scatter entirely on the NPU with two ACLNN calls (one for K, one for V), +// eliminating all D2H synchronisation and host-side loops. +// +// Requirement: slot_mapping must contain only non-negative values. Padding +// tokens (slot < 0) must be filtered by the caller before invoking this +// operator. +// +// The cache descriptors (`kv_k_cache_`, `kv_v_cache_`) and `v_offset_bytes_` +// are set up once in the constructor; `operator()` re-derives both cache +// pointers from the `kv_cache_out` it is given and does not read its +// `kv_cache` argument. +// +// NOTE(review): the status codes returned by +// `aclnnInplaceIndexCopyGetWorkspaceSize` and `aclnnInplaceIndexCopy` are +// discarded, so a failed scatter goes unnoticed here — confirm this is +// intended. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t total_slots = num_blocks * bs; + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::ToAclDtype(key.dtype()); + + // Flattened K cache view: [total_slots, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. 
+ kv_k_cache_ = ascend::AclTensorCache({total_slots, nkv, hs}, acl_dt, + kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. + v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {total_slots, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = static_cast(kv_cache_out.data()) + v_offset_bytes_; + + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + + // K cache scatter: kv_k[slot_mapping[i]] = key[i] along dim 0. + // Executor caching is not used here because aclnnInplaceIndexCopy is an + // inplace operation where self is both input and output; the executor + // reuse via aclSetInputTensorAddr does not update the output reference. + uint64_t k_ws = 0; + aclOpExecutor* k_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key, &k_ws, + &k_exec); + auto& k_arena = ascend::GetWorkspacePool().Ensure(stream, k_ws); + aclnnInplaceIndexCopy(k_arena.buf, k_ws, k_exec, stream); + + // V cache scatter: kv_v[slot_mapping[i]] = value[i] along dim 0. 
+ uint64_t v_ws = 0; + aclOpExecutor* v_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value, &v_ws, + &v_exec); + auto& v_arena = ascend::GetWorkspacePool().Ensure(stream, v_ws); + aclnnInplaceIndexCopy(v_arena.buf, v_ws, v_exec, stream); + } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/ascend/ops/reshape_and_cache/kernel_atb.h b/src/native/ascend/ops/reshape_and_cache/kernel_atb.h new file mode 100644 index 000000000..2b3c02207 --- /dev/null +++ b/src/native/ascend/ops/reshape_and_cache/kernel_atb.h @@ -0,0 +1,257 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/reshape_and_cache.h" +#include "native/ascend/atb_common_.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based KV cache scatter via `atb::infer::ReshapeAndCacheParam` +// (implementation index 2). +// +// Handles both K and V in a single fused operation. Profiled at ~9.5 us/call +// on Ascend 910B (256 tokens, fp16) — 3.7x faster than the +// `aclnnInplaceIndexCopy` path (index 0, ~35 us). +// +// The ATB operation is created once in the constructor. Setup is called +// before each Execute to bind the VariantPack. +// +// NOTE: `ReshapeAndCacheParam` requires int32 `slot_mapping`. 
When the +// caller passes int64 (the PyTorch / vLLM default), this operator issues an +// async `aclnnCast` to a pre-allocated int32 device buffer. The cast +// executor is cached across calls and the whole step stays on the stream +// with no D2H/H2D round-trip, so the int64 path is NPUGraph-capturable and +// roughly on par with the int32 fast path. +// +// Input layout: +// key, value : [num_tokens, num_kv_heads, head_size] +// slot_mapping: [num_tokens] (int32 or int64) +// +// KV cache layout: +// kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size] +// Output key_cache = kv_cache[0], value_cache = kv_cache[1], each with +// shape [num_blocks, block_size, num_kv_heads, head_size]. +// +// Whether the cast path is taken is decided once, in the constructor, from +// `slot_mapping.element_size()` (stored in `slot_is_int32_`). +// +// NOTE(review): `cast_exec_` is marked repeatable via +// `aclSetAclOpExecutorRepeatable`, but the destructor never calls +// `aclDestroyAclOpExecutor` on it — confirm whether this is intentional or a +// leak (and how it interacts with the `release()` calls on the caches). +// NOTE(review): failures of `aclrtMalloc`, `atb::CreateOperation`, `Setup` +// and `Execute` are only caught by `assert`, which compiles out under +// `NDEBUG`; the `aclnnCast*` return codes are not checked at all. +// NOTE(review): `slot32_buf_` is allocated even when `slot_is_int32_` is +// true, in which case `operator()` passes the caller's buffer through and +// never reads `slot32_buf_`. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + int64_t T = static_cast(num_tokens_); + + // Cache shapes for rebuilding VariantPack on each call. + kv_shape_ = {num_blocks, bs, nkv, hs}; + key_shape_ = {T, nkv, hs}; + slot_shape_ = {T}; + acl_dt_ = ascend::ToAclDtype(key.dtype()); + + // Compute V-cache byte offset (kv_cache_out[1]). + v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + + // Element sizes for dataSize computation. + elem_size_ = key.element_size(); + + // Pre-allocate int32 device buffer for `slot_mapping`. + // `ReshapeAndCacheParam` requires int32; int64 is silently ignored + // (writes nothing). 
+ slot32_bytes_ = static_cast(T) * sizeof(int32_t); + aclrtMalloc(&slot32_buf_, slot32_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(slot32_buf_ && "aclrtMalloc for slot32_buf_ failed"); + + slot_is_int32_ = (slot_mapping.element_size() == sizeof(int32_t)); + + // Prepare aclnnCast descriptors for the int64 → int32 path. Source + // descriptor's data pointer is refreshed per call; destination is the + // pre-allocated `slot32_buf_`. + if (!slot_is_int32_) { + slot_i64_cache_ = ascend::AclTensorCache( + {T}, ACL_INT64, const_cast(slot_mapping.data())); + slot_i32_cache_ = ascend::AclTensorCache({T}, ACL_INT32, slot32_buf_); + } + + // Create the ATB operation (reused across calls). + atb::infer::ReshapeAndCacheParam param; + atb::Status s = atb::CreateOperation(param, &op_); + assert(s == atb::NO_ERROR && + "atb::CreateOperation(ReshapeAndCache) failed"); + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + if (op_) atb::DestroyOperation(op_); + slot_i64_cache_.release(); + slot_i32_cache_.release(); + if (slot32_buf_) aclrtFree(slot32_buf_); + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + // `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the + // caller provides int64 (the PyTorch/vLLM default), issue an async + // `aclnnCast` to the pre-allocated int32 device buffer — keeps the + // whole step on-stream and NPUGraph-capturable. + void* slot32_ptr; + + if (slot_is_int32_) { + // Already int32 — pass through directly. 
+ slot32_ptr = const_cast(slot_mapping.data()); + } else { + auto t_src = slot_i64_cache_.get(const_cast(slot_mapping.data())); + auto t_dst = slot_i32_cache_.get(slot32_buf_); + + if (!cast_exec_) { + aclnnCastGetWorkspaceSize(t_src, ACL_INT32, t_dst, &cast_ws_, + &cast_exec_); + aclSetAclOpExecutorRepeatable(cast_exec_); + } else { + aclSetInputTensorAddr(cast_exec_, 0, t_src, + const_cast(slot_mapping.data())); + aclSetOutputTensorAddr(cast_exec_, 0, t_dst, slot32_buf_); + } + + auto& cast_arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_); + aclnnCast(cast_arena.buf, cast_ws_, cast_exec_, stream); + slot32_ptr = slot32_buf_; + } + + atb::Context* ctx = ascend::GetAtbContext(stream); + + atb::VariantPack vp = buildVariantPack(const_cast(key.data()), + const_cast(value.data()), + kv_cache_out.data(), slot32_ptr); + + // Setup binds the VariantPack and computes workspace requirements. + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Setup(ReshapeAndCache) failed"); + + // Allocate workspace via the shared pool. + uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Execute(ReshapeAndCache) failed"); + } + + private: + // Build the ATB VariantPack for this operation. 
+ // + // ATB `ReshapeAndCache` expects 5 inputs and 2 outputs: + // inTensors[0] = key [num_tokens, num_kv_heads, head_size] + // inTensors[1] = value [num_tokens, num_kv_heads, head_size] + // inTensors[2] = key_cache [num_blocks, block_size, num_kv_heads, + // head_size] + // inTensors[3] = value_cache [num_blocks, block_size, num_kv_heads, + // head_size] + // inTensors[4] = slot_mapping [num_tokens] (int32) + // outTensors[0] = key_cache (same buffer, in-place) + // outTensors[1] = value_cache (same buffer, in-place) + // + // The value-cache pointer is `kv_out_data` advanced by `v_offset_bytes_`; + // byte sizes are element counts from the stored shapes times `elem_size_`. + atb::VariantPack buildVariantPack(void* key_data, void* value_data, + void* kv_out_data, + void* slot32_data) const { + int64_t num_tokens = key_shape_[0]; + int64_t nkv = key_shape_[1]; + int64_t hs = key_shape_[2]; + uint64_t kv_bytes = + static_cast(num_tokens * nkv * hs) * elem_size_; + + int64_t nb = kv_shape_[0]; + int64_t bs = kv_shape_[1]; + uint64_t cache_bytes = + static_cast(nb * bs * nkv * hs) * elem_size_; + + void* v_out_data = static_cast(kv_out_data) + v_offset_bytes_; + + atb::Tensor t_key = + ascend::ToAtbTensor(key_shape_, acl_dt_, key_data, kv_bytes); + + atb::Tensor t_value = + ascend::ToAtbTensor(key_shape_, acl_dt_, value_data, kv_bytes); + + atb::Tensor t_kv_k = + ascend::ToAtbTensor(kv_shape_, acl_dt_, kv_out_data, cache_bytes); + + atb::Tensor t_kv_v = + ascend::ToAtbTensor(kv_shape_, acl_dt_, v_out_data, cache_bytes); + + // Always int32 — the caller's `operator()` has already cast to int32. + atb::Tensor t_slot = + ascend::ToAtbTensor(slot_shape_, ACL_INT32, slot32_data, slot32_bytes_); + + atb::VariantPack vp; + vp.inTensors = {t_key, t_value, t_kv_k, t_kv_v, t_slot}; + vp.outTensors = {t_kv_k, t_kv_v}; + + return vp; + } + + atb::Operation* op_ = nullptr; + + std::vector kv_shape_; + + std::vector key_shape_; + + std::vector slot_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + size_t v_offset_bytes_ = 0; + + uint64_t elem_size_ = 0; + + // Pre-allocated int32 device buffer for `slot_mapping`. 
+ void* slot32_buf_ = nullptr; + + size_t slot32_bytes_ = 0; + + // True if the caller already provides int32 `slot_mapping`. + bool slot_is_int32_ = false; + + // Cached aclnnCast descriptors (int64 slot_mapping → int32 buffer). + mutable ascend::AclTensorCache slot_i64_cache_; + + mutable ascend::AclTensorCache slot_i32_cache_; + + mutable aclOpExecutor* cast_exec_ = nullptr; + + mutable uint64_t cast_ws_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ diff --git a/src/native/ascend/ops/reshape_and_cache/kernel_v2.h b/src/native/ascend/ops/reshape_and_cache/kernel_v2.h new file mode 100644 index 000000000..684506b27 --- /dev/null +++ b/src/native/ascend/ops/reshape_and_cache/kernel_v2.h @@ -0,0 +1,123 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ + +// WARNING: This implementation is experimental and has strict hardware limits. +// +// Limitations: +// 1. Requires CANN 8.5.1+ (`aclnnScatterPaKvCache` API). +// 2. Only supported on Atlas A5 hardware (SoC 260). NOT supported on +// A2 (Ascend 910B, SoC 220-225) or A3 (SoC 250-255). +// 3. Not yet validated in production workloads. +// +// When `aclnnop/aclnn_scatter_pa_kv_cache.h` is not available this file +// compiles to nothing (guarded by `__has_include`). NOTE(review): the guard +// only tests for the header, not the SoC, so this can still compile for +// unsupported hardware. Use `implementation_index=0` (the default +// `aclnnInplaceIndexCopy` path) for general-purpose deployment. + +#if __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_scatter_pa_kv_cache.h" +#include "base/reshape_and_cache.h" +#include "native/ascend/common.h" +#include "native/ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Fused KV cache scatter via `aclnnScatterPaKvCache` (implementation index 1). 
+// +// Handles both K and V scatter in a single CANN kernel launch, replacing two +// separate `aclnnInplaceIndexCopy` calls (index 0). The fused API is +// purpose-built for paged KV cache and avoids the internal decomposition to +// `ScatterElementsV2`. +// +// Requirements: +// - CANN 8.5.1+ (`aclnnop/aclnn_scatter_pa_kv_cache.h`). +// - Atlas A5 hardware (SoC 260). The API is NOT supported on A2 (910B, +// SoC 220-225) or A3 (SoC 250-255). +// +// Select via `implementation_index=1` in Python: +// infini.ops.reshape_and_cache(..., implementation_index=1, stream=s) +// +// The 4D cache descriptors and `v_offset_bytes_` are set up once in the +// constructor; `operator()` re-derives both cache pointers from the +// `kv_cache_out` it is given and passes `nullptr` for every optional argument +// of `aclnnScatterPaKvCacheGetWorkspaceSize`. +// +// NOTE(review): the status codes returned by +// `aclnnScatterPaKvCacheGetWorkspaceSize` and `aclnnScatterPaKvCache` are +// discarded — confirm this is intended. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::ToAclDtype(key.dtype()); + + // 4D K cache view: [num_blocks, block_size, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. + kv_k_cache_ = ascend::AclTensorCache({num_blocks, bs, nkv, hs}, acl_dt, + kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. 
+ v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {num_blocks, bs, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = static_cast(kv_cache_out.data()) + v_offset_bytes_; + + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + + // Single fused scatter for both K and V caches. + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnScatterPaKvCacheGetWorkspaceSize( + t_key, t_kv_k, t_slot, t_value, t_kv_v, + /*compressLensOptional=*/nullptr, + /*compressSeqOffsetOptional=*/nullptr, + /*seqLensOptional=*/nullptr, + /*cacheModeOptional=*/nullptr, + /*scatterModeOptional=*/nullptr, + /*stridesOptional=*/nullptr, + /*offsetsOptional=*/nullptr, &ws, &exec); + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws); + aclnnScatterPaKvCache(arena.buf, ws, exec, stream); + } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif // __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + +#endif // INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ diff --git a/tests/test_reshape_and_cache.py b/tests/test_reshape_and_cache.py new file mode 100644 index 000000000..4f69501f6 --- /dev/null +++ b/tests/test_reshape_and_cache.py @@ -0,0 +1,273 
@@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + +# ReshapeAndCache only works on NPU (aclrtMemcpy-based), so tests only +# parametrize on float16/bfloat16 and use explicit device parametrization. +# NOTE(review): the "aclrtMemcpy-based" wording may be stale — confirm it +# against the current kernels. + +# `aclnnScatterPaKvCache` (index 1) requires Atlas A5 (SoC 260). It compiles +# on 910B (CANN 8.5.1 headers present) but produces wrong results at runtime. +_SKIP_INDEX_1 = pytest.mark.skip( + reason="`aclnnScatterPaKvCache` (index 1) requires Atlas A5; " + "not supported on Ascend 910B" +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (1, 8, 128, 4, 16), + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + (16, 2, 128, 8, 64), + ), +) +@pytest.mark.parametrize( + "implementation_index", + (0, pytest.param(1, marks=_SKIP_INDEX_1), 2), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_contiguous( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + implementation_index, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + # Layout: [2, num_blocks, block_size, num_kv_heads, head_size] + # Index 0 = key cache, index 1 = value cache. 
+ kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Contiguous slot mapping: token i -> slot i. + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + return Payload( + lambda *args, **kwargs: _reshape_and_cache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + ), +) +@pytest.mark.parametrize( + "implementation_index", + (0, pytest.param(1, marks=_SKIP_INDEX_1), 2), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_noncontiguous_slots( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + implementation_index, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Non-contiguous slots: skip every other slot. 
+ slot_mapping = torch.tensor( + [i * 2 for i in range(num_tokens)], dtype=torch.int64, device=device + ) + + return Payload( + lambda *args, **kwargs: _reshape_and_cache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (8, 8, 128, 4, 16), + (16, 4, 64, 8, 32), + ), +) +@pytest.mark.parametrize( + "implementation_index", + (0, pytest.param(1, marks=_SKIP_INDEX_1), 2), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_padding_slots( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + implementation_index, + dtype, + rtol, + atol, + device, +): + """Graph-padded decode: slots with `-1` must be skipped, not written. + + `aclnnInplaceIndexCopy` silently treats `slot=-1` as "last index" which + corrupts the last KV cache entry. The wrapper must filter `-1` slots + before calling the underlying op. 
+ """ + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + + # Every other token is a padding slot (`-1`); valid slots map to unique + # contiguous positions so a correct wrapper leaves the final entry of + # the last block untouched. + # NOTE(review): the `_reshape_and_cache` helper in this file forwards + # `slot_mapping` unchanged, so skipping `-1` slots has to happen inside + # `infini.ops.reshape_and_cache` — confirm for every implementation index. + slot_values = [] + valid = 0 + + for i in range(num_tokens): + if i % 2 == 0: + slot_values.append(-1) + else: + slot_values.append(valid) + valid += 1 + + slot_mapping = torch.tensor(slot_values, dtype=torch.int64, device=device) + + return Payload( + lambda *args, **kwargs: _reshape_and_cache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +# Runs the operator under test on the stream from `get_stream(key.device)`, +# passing `kv_cache_out` as the output tensor, and returns `kv_cache_out`. +def _reshape_and_cache( + key, value, kv_cache, slot_mapping, kv_cache_out, implementation_index=0 +): + infini.ops.reshape_and_cache( + key, + value, + kv_cache, + slot_mapping, + kv_cache_out, + implementation_index=implementation_index, + stream=get_stream(key.device), + ) + + return kv_cache_out + + +# Reference implementation: works on a clone of `kv_cache_out`, skips tokens +# whose slot is negative, and writes key/value for slot `s` at block +# `s // block_size`, offset `s % block_size` (`block_size` is dim 2 of the +# cache). +def _ref_reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + kv_cache_out = kv_cache_out.clone() + slots = slot_mapping.cpu() + block_size = kv_cache_out.size(2) + + for i in range(key.size(0)): + slot = int(slots[i].item()) + + if slot < 0: + continue + + block_idx = slot // block_size + offset = slot % block_size + kv_cache_out[0, block_idx, 
offset, :, :] = key[i, :, :] + kv_cache_out[1, block_idx, offset, :, :] = value[i, :, :] + + return kv_cache_out