From 6077feade5f2689cf76a62a17e331956576d4c92 Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 22 Feb 2026 23:59:15 +0000 Subject: [PATCH 1/7] Add int8 quantization for vortex. Key changes: 1. Memory Pool (`vtx_graph_memory_pool.py`): - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations. - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout. - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids) to ensure the Vortex sparse indexer/TopK remains unaffected and mathematically identical. 2. Quantize-on-Write (`set_kv.py`): - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`). - Wired the new launcher into the cache update flow. 3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`): - Bypassed FlashInfer for INT8 decoding. - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers. - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`). 4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`): - Implemented an OOM-safe `bf16` fallback for prefill. - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer. - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs. 
--- CLAUDE.md | 88 +++++ examples/verify_algo.py | 27 +- examples/verify_algo.sh | 27 +- examples/verify_algo_quant.sh | 25 ++ vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/triton_kernels/__init__.py | 11 +- .../cache/triton_kernels/paged_decode_int8.py | 355 ++++++++++++++++++ .../triton_kernels/paged_prefill_int8.py | 90 +++++ vortex_torch/cache/triton_kernels/set_kv.py | 87 +++++ 9 files changed, 696 insertions(+), 17 deletions(-) create mode 100644 CLAUDE.md create mode 100644 examples/verify_algo_quant.sh create mode 100644 vortex_torch/cache/triton_kernels/paged_decode_int8.py create mode 100644 vortex_torch/cache/triton_kernels/paged_prefill_int8.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..db54c75 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,88 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. + +## Build & Install + +```bash +# Install SGLang dependency (custom fork in third_party/) +cd third_party/sglang && bash install.sh && cd ../../ + +# Install Vortex (editable mode, compiles CUDA extensions for SM_89/SM_90) +pip install -e . +``` + +Requires Python >=3.10, torch>=2.7. CUDA extensions are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu). + +## Running Examples + +```bash +# Single algorithm verification against SGLang +python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention + +# Batch test multiple algorithms +bash examples/verify_algo.sh +``` + +## Building Documentation + +```bash +make -C docs html +``` + +Uses Sphinx with myst_parser and furo theme. 
Deployed via GitHub Actions on push to v1 branch. + +## Architecture + +### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) + +All sparse attention algorithms inherit from `vFlow` and implement three methods: + +- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. +- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. +- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. + +Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. + +### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) + +Operators (`vOp` subclasses) run in two modes: +- **Profile mode**: Pre-compute output shapes and allocate buffers +- **Execute mode**: Perform actual GPU computation + +Operators are split into two parallel hierarchies: +- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load +- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup + +Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. + +### Tensor Format (`vortex_torch/abs/tensor.py`) + +`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. + +### Context System (`vortex_torch/abs/context_base.py`) + +`ContextBase` carries per-step runtime state. 
Specialized as: +- `Indexer.Context`: Page layout, head config, hardware info +- `Cache.Context`: Page size, total pages, model info + +### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) + +- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) +- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation +- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds + +### SGLang Integration + +Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, and transpose operations. + +## Key Conventions + +- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` +- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) +- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` +- **Branch**: Main development is on `v1` diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e290a81..f418598 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -54,7 +54,8 @@ def verify_algos( vortex_module_name: str = "gqa_block_sparse_attention", model_name: str = "Qwen/Qwen3-1.7B", sparse_attention: bool = True, -mem: float = 0.8 +mem: float = 0.8, +kv_cache_dtype: str = "auto", ): llm = sgl.Engine(model_path=model_name, @@ -69,10 +70,11 @@ def verify_algos( vortex_layers_skip=list(range(1)), vortex_module_name=vortex_module_name, vortex_max_seq_lens=12288, - mem_fraction_static=mem + mem_fraction_static=mem, + kv_cache_dtype=kv_cache_dtype, ) - with open("examples/amc23.jsonl", "r", encoding="utf-8") as f: + with open("amc23.jsonl", "r", encoding="utf-8") as f: requests = [json.loads(line) for line in f] requests = requests * trials @@ -110,6 +112,14 @@ def verify_algos( 
"num_tokens": item["meta_info"]["completion_tokens"] } ) + # --- Per-question debug output --- + print(f"[Q{len(results):03d}] score={float(result):.1f} " + f"tokens={item['meta_info']['completion_tokens']} " + f"latency={item['meta_info']['e2e_latency']:.2f}s " + f"gold={golds[0]}") + print(f" question: {data['question'][:120]}...") + print(f" prediction: {predictions[:200]}...") + print() total_accuracy = 0.0 @@ -203,6 +213,14 @@ def parse_args(): default=0.8, help="memory fraction in sglang", ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + default="auto", + choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], + help='KV cache dtype (default: "auto").', + ) return parser.parse_args() if __name__ == "__main__": @@ -215,7 +233,8 @@ def parse_args(): vortex_module_name=args.vortex_module_name, model_name=args.model_name, sparse_attention=not(args.full_attention), - mem=args.mem + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, ) print(summary) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 17c2a5e..d80f09a 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,17 +1,24 @@ #!/usr/bin/env bash set -e +export CUDA_VISIBLE_DEVICES=1 sparse_algos=( - + "block_sparse_attention" ) -for algo in "${sparse_algos[@]}"; do - echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" - python examples/verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --mem 0.7 -done +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --mem 0.7 ; } \ + 2>&1 | tee 
"${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh new file mode 100644 index 0000000..4cf1366 --- /dev/null +++ b/examples/verify_algo_quant.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=2 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index eddfa46..b886559 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,11 +29,12 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher __all__ = [ "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 6bf6dfc..2d6384f 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,4 +1,11 @@ -from .set_kv import set_kv_buffer_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .paged_decode_int8 import paged_decode_int8 +from .paged_prefill_int8 import 
dequant_paged_int8_to_bf16 -__all__ = ["set_kv_buffer_launcher"] +__all__ = [ + "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "paged_decode_int8", + "dequant_paged_int8_to_bf16", +] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py new file mode 100644 index 0000000..480c787 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -0,0 +1,355 @@ +""" +Custom Triton paged decode attention kernel for int8 KV cache. + +Loads int8 K/V pages with per-token float32 scales, dequantizes inline in SRAM, +and computes standard multi-head attention with online softmax. + +Adapted from SGLang's decode_attention.py for use with Vortex's paged layout +where each KV head is treated as a separate "batch" entry. +""" + +import torch +import triton +import triton.language as tl + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_int8_stage1( + Q, # [batch, num_qo_heads, head_dim] bf16 + K_Buffer, # int8 paged: flat + V_Buffer, # int8 paged: flat + K_Scale_Buffer, # float32: flat (one scale per token slot) + V_Scale_Buffer, # float32: flat + sm_scale, + kv_indptr, # [batch + 1] int32, page-level + kv_indices, # page indices + last_page_len, # [batch] int32, tokens valid in last page + Att_Out, # [batch, num_qo_heads, max_kv_splits, head_dim] + Att_Lse, # [batch, num_qo_heads, max_kv_splits] + num_kv_splits, # [batch] int32 + stride_qbs, + stride_qh, + stride_buf_kbs, # stride per token in K_Buffer (= head_dim) + stride_buf_vbs, # stride per token in V_Buffer (= head_dim) + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """ + Stage 1: For each (batch, head, kv_split), 
compute partial attention output and LSE. + + kv_indptr is page-level. Total tokens for batch i: + (num_pages - 1) * PAGE_SIZE + last_page_len[i] + """ + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + # Correct token count accounting for partial last page + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + # Convert token offsets to page_id + in-page offset + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + + # Load page indices from kv_indices (physical page IDs) + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, + other=0, + ) + + # Flat token location: physical_page * PAGE_SIZE + in_page_offset + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load int8 K and dequantize + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & 
mask_d[None, :], + other=0, + ).to(tl.float32) + + k_scale = tl.load( + K_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ) + k = k_int8 * k_scale[:, None] + + # Compute QK + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load int8 V and dequantize + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], + other=0, + ).to(tl.float32) + + v_scale = tl.load( + V_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ) + v = v_int8 * v_scale[:, None] + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=mask_dv, + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +@triton.jit +def _fwd_kernel_int8_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len 
+ kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode_int8( + q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 + k_buffer: torch.Tensor, # int8 paged K cache + v_buffer: torch.Tensor, # int8 paged V cache + k_scale_buffer: torch.Tensor, # float32 scale for K + v_scale_buffer: torch.Tensor, # float32 scale for V + o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output + kv_indptr: torch.Tensor, # [batch + 1] int32, page-level + kv_indices: torch.Tensor, # page indices + last_page_len: torch.Tensor, # [batch] int32 + num_kv_splits: torch.Tensor, # [batch] int32 + max_kv_splits: int, + sm_scale: float, + page_size: int, + logit_cap: float = 0.0, +): + """ + Paged decode attention with int8 KV cache and inline dequantization. + + kv_indptr is page-level. last_page_len specifies valid tokens in the last page + for each batch entry. 
Total tokens = (num_pages - 1) * page_size + last_page_len. + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 64 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + + num_warps = 4 if kv_group_num == 1 else 2 + + # Intermediate buffers for split reduction + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_int8_stage1[grid_stage1]( + q, + k_buffer, + v_buffer, + k_scale_buffer, + v_scale_buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + att_out, + att_lse, + num_kv_splits, + q.stride(0), + q.stride(1), + stride_buf_kbs, + stride_buf_vbs, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + PAGE_SIZE=page_size, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_int8_stage2[grid_stage2]( + att_out, + att_lse, + o, + kv_indptr, + last_page_len, + num_kv_splits, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + o.stride(0), + o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py new file mode 100644 index 0000000..75c3857 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -0,0 +1,90 @@ +""" +OOM-safe bf16 fallback for int8 
KV-cache prefill. + +Instead of implementing full 2D-tiled Triton prefill with int8 dequantization, +this module dequantizes only the accessed KV pages into a compact temporary +bf16 buffer and remaps indices so FlashInfer can operate on the compact buffer. + +This avoids dequantizing the entire global cache buffer. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _dequant_pages_kernel( + src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat + src_scale, # float32 scale buffer [num_pages, page_size, 1] flat + dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat + page_indices, # int32 [num_accessed_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16 compact buffer.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source: global_page_id * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + src_offset, mask=mask_dim, other=0).to(tl.float32) + + # Scale: global_page_id * PAGE_SIZE + token_idx + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Destination: page_idx * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + tl.store(dst_bf16 + dst_offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16( + src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] + src_scale: torch.Tensor, # float32 [num_pages, page_size, 1] + 
page_indices: torch.Tensor, # int32 [num_accessed_pages] + page_size: int, + head_dim: int, +) -> torch.Tensor: + """ + Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + + Returns: + bf16 tensor of shape [num_accessed_pages, page_size, head_dim] + """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) + + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) + + return dst_bf16 diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index cfa3cab..4318428 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -36,6 +36,93 @@ def set_kv_buffer_kernel( tl.store(dst_v_ptr, src_v) +@triton.jit +def set_kv_buffer_int8_kernel( + k_cache, # int8 paged K cache + v_cache, # int8 paged V cache + k_scale_cache, # float32 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # float32 per-token V scale [num_pages, page_size, 1] + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr +): + """Quantize bf16 K/V to int8 with per-token absmax scaling and write to paged buffers.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = 
tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Compute per-token absmax scale: scale = absmax / 127 + absmax_k = tl.max(tl.abs(src_k), axis=0) + absmax_v = tl.max(tl.abs(src_v), axis=0) + # Avoid division by zero + scale_k = absmax_k / 127.0 + 1e-10 + scale_v = absmax_v / 127.0 + 1e-10 + + # Quantize to int8: round(x / scale), clamp to [-128, 127] + q_k = tl.extra.cuda.libdevice.rint(src_k / scale_k) + q_k = tl.minimum(tl.maximum(q_k, -128.0), 127.0).to(tl.int8) + q_v = tl.extra.cuda.libdevice.rint(src_v / scale_v) + q_v = tl.minimum(tl.maximum(q_v, -128.0), 127.0).to(tl.int8) + + # Compute paged destination offset (same layout as bf16 kernel) + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write int8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + # Write per-token scales: shape [num_pages, page_size, 1] + # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) + scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset + tl.store(k_scale_cache + scale_offset, scale_k) + tl.store(v_scale_cache + scale_offset, scale_v) + + +def set_kv_buffer_int8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int +): + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_int8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + new_k, + new_v, + loc, + NUM_KV_HEAD, + NNZ, + HEAD_DIM, + page_size + ) + + def set_kv_buffer_launcher( k_cache: 
torch.Tensor, v_cache: torch.Tensor, From 1f52772d451ce0e8663784c1052475b4f7b2618f Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 23 Feb 2026 03:19:56 +0000 Subject: [PATCH 2/7] 1. Add support for pro 6000. 2. Correction for vortex --- CLAUDE.md | 53 +++++++++++ setup.py | 5 +- third_party/sglang | 2 +- vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/triton_kernels/__init__.py | 3 +- .../cache/triton_kernels/paged_decode_int8.py | 42 +++++---- .../triton_kernels/paged_prefill_int8.py | 94 +++++++++++++++++-- vortex_torch/cache/triton_kernels/set_kv.py | 10 +- 8 files changed, 178 insertions(+), 34 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index db54c75..1593d61 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -86,3 +86,56 @@ Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch) - **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) - **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` - **Branch**: Main development is on `v1` + +## Workflow Orchestration + +### 1. Plan Node Default +- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) +- If something goes sideways, STOP and re-plan immediately - don't keep pushing +- Use plan mode for verification steps, not just building +- Write detailed specs upfront to reduce ambiguity + +### 2. Subagent Strategy +- Use subagents liberally to keep main context window clean +- Offload research, exploration, and parallel analysis to subagents +- For complex problems, throw more compute at it via subagents +- One tack per subagent for focused execution + +### 3. Self-Improvement Loop +- After ANY correction from the user: update `tasks/lessons.md` with the pattern +- Write rules for yourself that prevent the same mistake +- Ruthlessly iterate on these lessons until mistake rate drops +- Review lessons at session start for relevant project + +### 4. 
Verification Before Done +- Never mark a task complete without proving it works +- Diff behavior between main and your changes when relevant +- Ask yourself: "Would a staff engineer approve this?" +- Run tests, check logs, demonstrate correctness + +### 5. Demand Elegance (Balanced) +- For non-trivial changes: pause and ask "is there a more elegant way?" +- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" +- Skip this for simple, obvious fixes - don't over-engineer +- Challenge your own work before presenting it + +### 6. Autonomous Bug Fixing +- When given a bug report: just fix it. Don't ask for hand-holding +- Point at logs, errors, failing tests - then resolve them +- Zero context switching required from the user +- Go fix failing CI tests without being told how + +## Task Management + +1. **Plan First**: Write plan to `tasks/todo.md` with checkable items +2. **Verify Plan**: Check in before starting implementation +3. **Track Progress**: Mark items complete as you go +4. **Explain Changes**: High-level summary at each step +5. **Document Results**: Add review section to `tasks/todo.md` +6. **Capture Lessons**: Update `tasks/lessons.md` after corrections + +## Core Principles + +- **Simplicity First**: Make every change as simple as possible. Impact minimal code. +- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. +- **Minimal Impact**: Changes should only touch what's necessary. Avoid introducing bugs. 
\ No newline at end of file diff --git a/setup.py b/setup.py index e272326..6efeebe 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,11 @@ 'cxx': ['-O3'], 'nvcc': [ '-O3', + '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', - '-gencode=arch=compute_90,code=sm_90' + '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' ], }, ), diff --git a/third_party/sglang b/third_party/sglang index e383c0f..9672e9a 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit e383c0fdd551f74f24d247e8a7cc8013861949ad +Subproject commit 9672e9a7f90bcb782ccdfb2ee123ede7f2ef5d17 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index b886559..b32d8bc 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,12 +29,13 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 2d6384f..a18067e 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,11 +1,12 @@ from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher from .paged_decode_int8 import paged_decode_int8 -from .paged_prefill_int8 import dequant_paged_int8_to_bf16 +from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace __all__ = [ 
"set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "paged_decode_int8", "dequant_paged_int8_to_bf16", + "dequant_paged_int8_to_bf16_inplace", ] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py index 480c787..4f33cd4 100644 --- a/vortex_torch/cache/triton_kernels/paged_decode_int8.py +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -25,8 +25,8 @@ def _fwd_kernel_int8_stage1( Q, # [batch, num_qo_heads, head_dim] bf16 K_Buffer, # int8 paged: flat V_Buffer, # int8 paged: flat - K_Scale_Buffer, # float32: flat (one scale per token slot) - V_Scale_Buffer, # float32: flat + K_Scale_Buffer, # fp16: flat (one scale per token slot) + V_Scale_Buffer, # fp16: flat sm_scale, kv_indptr, # [batch + 1] int32, page-level kv_indices, # page indices @@ -118,7 +118,7 @@ def _fwd_kernel_int8_stage1( K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, - ) + ).to(tl.float32) k = k_int8 * k_scale[:, None] # Compute QK @@ -142,7 +142,7 @@ def _fwd_kernel_int8_stage1( V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, - ) + ).to(tl.float32) v = v_int8 * v_scale[:, None] # Online softmax accumulation @@ -251,8 +251,8 @@ def paged_decode_int8( q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 k_buffer: torch.Tensor, # int8 paged K cache v_buffer: torch.Tensor, # int8 paged V cache - k_scale_buffer: torch.Tensor, # float32 scale for K - v_scale_buffer: torch.Tensor, # float32 scale for V + k_scale_buffer: torch.Tensor, # fp16 scale for K + v_scale_buffer: torch.Tensor, # fp16 scale for V o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output kv_indptr: torch.Tensor, # [batch + 1] int32, page-level kv_indices: torch.Tensor, # page indices @@ -262,6 +262,8 @@ def paged_decode_int8( sm_scale: float, page_size: int, logit_cap: float = 0.0, + att_out: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits, Lv] + att_lse: torch.Tensor = None, # optional pre-allocated 
[batch, head_num, max_kv_splits] ): """ Paged decode attention with int8 KV cache and inline dequantization. @@ -283,17 +285,23 @@ def paged_decode_int8( num_warps = 4 if kv_group_num == 1 else 2 - # Intermediate buffers for split reduction - att_out = torch.empty( - (batch, head_num, MAX_KV_SPLITS, Lv), - dtype=torch.float32, - device=q.device, - ) - att_lse = torch.empty( - (batch, head_num, MAX_KV_SPLITS), - dtype=torch.float32, - device=q.device, - ) + # Use pre-allocated buffers if provided, otherwise allocate + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + else: + att_lse = att_lse[:batch] stride_buf_kbs = k_buffer.shape[-1] stride_buf_vbs = v_buffer.shape[-1] diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py index 75c3857..8927983 100644 --- a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -16,7 +16,7 @@ @triton.jit def _dequant_pages_kernel( src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat - src_scale, # float32 scale buffer [num_pages, page_size, 1] flat + src_scale, # fp16 scale buffer [num_pages, page_size, 1] flat dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat page_indices, # int32 [num_accessed_pages] — which global pages to dequant NUM_PAGES: tl.constexpr, @@ -41,7 +41,7 @@ def _dequant_pages_kernel( # Scale: global_page_id * PAGE_SIZE + token_idx scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset) + scale = tl.load(src_scale + scale_offset).to(tl.float32) val_bf16 = (val_int8 * scale).to(tl.bfloat16) @@ -52,26 +52,35 @@ def _dequant_pages_kernel( def 
dequant_paged_int8_to_bf16( src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] - src_scale: torch.Tensor, # float32 [num_pages, page_size, 1] + src_scale: torch.Tensor, # fp16 [num_pages, page_size, 1] page_indices: torch.Tensor, # int32 [num_accessed_pages] page_size: int, head_dim: int, + out: torch.Tensor = None, # optional pre-allocated bf16 [>=num_accessed_pages, page_size, head_dim] ) -> torch.Tensor: """ Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + If `out` is provided, writes into it (must have room for num_accessed_pages). + Otherwise allocates a new buffer. + Returns: bf16 tensor of shape [num_accessed_pages, page_size, head_dim] """ num_accessed_pages = page_indices.shape[0] if num_accessed_pages == 0: + if out is not None: + return out[:0] return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) - dst_bf16 = torch.empty( - (num_accessed_pages, page_size, head_dim), - dtype=torch.bfloat16, - device=src_int8.device, - ) + if out is not None: + dst_bf16 = out[:num_accessed_pages] + else: + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) BLOCK_DIM = triton.next_power_of_2(head_dim) @@ -88,3 +97,72 @@ def dequant_paged_int8_to_bf16( ) return dst_bf16 + + +@triton.jit +def _dequant_pages_inplace_kernel( + src_int8, # int8 paged buffer flat + src_scale, # scale buffer flat (one scale per token slot) + dst_bf16, # bf16 destination buffer (same page layout as src) + page_indices, # int32 [num_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16, writing to the SAME page positions in dst.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = 
tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source and destination use the SAME offset (in-place layout) + offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + offset, mask=mask_dim, other=0).to(tl.float32) + + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset).to(tl.float32) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Write to the SAME page position in dst (not compacted) + tl.store(dst_bf16 + offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16_inplace( + src_int8: torch.Tensor, # int8 paged cache (flat) + src_scale: torch.Tensor, # fp16 scale buffer (flat) + dst_bf16: torch.Tensor, # bf16 destination (same shape as src_int8) + page_indices: torch.Tensor, # int32 [num_pages] — which pages to dequant + page_size: int, + head_dim: int, +) -> None: + """ + Dequantize selected pages from int8 cache to bf16 IN-PLACE. + + Unlike dequant_paged_int8_to_bf16 (which compacts into a dense buffer), + this writes to the SAME page positions in dst_bf16, preserving the paged layout. + Used to populate the bf16 working buffer for forward_cache (centroid computation). 
+ """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_inplace_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 4318428..2a2c785 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -40,8 +40,8 @@ def set_kv_buffer_kernel( def set_kv_buffer_int8_kernel( k_cache, # int8 paged K cache v_cache, # int8 paged V cache - k_scale_cache, # float32 per-token K scale [num_pages, page_size, 1] - v_scale_cache, # float32 per-token V scale [num_pages, page_size, 1] + k_scale_cache, # fp16 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # fp16 per-token V scale [num_pages, page_size, 1] new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] loc, # int64 token positions @@ -87,11 +87,11 @@ def set_kv_buffer_int8_kernel( tl.store(dst_k_ptr, q_k) tl.store(dst_v_ptr, q_v) - # Write per-token scales: shape [num_pages, page_size, 1] + # Write per-token scales (fp16): shape [num_pages, page_size, 1] # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset - tl.store(k_scale_cache + scale_offset, scale_k) - tl.store(v_scale_cache + scale_offset, scale_v) + tl.store(k_scale_cache + scale_offset, scale_k.to(tl.float16)) + tl.store(v_scale_cache + scale_offset, scale_v.to(tl.float16)) def set_kv_buffer_int8_launcher( From 584f23355412a4464215ccb15d06758ad1b2762c Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 23 Feb 2026 07:07:28 +0000 Subject: [PATCH 3/7] 1. Correction on int8 (maximize memory occupation) 2. 
Implement fp8 quantization. --- CLAUDE.md | 141 -------- examples/verify_algo.py | 14 +- examples/verify_algo_fp8.sh | 25 ++ ...rify_algo_quant.sh => verify_algo_int8.sh} | 0 third_party/sglang | 2 +- vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/context.py | 20 +- vortex_torch/cache/reduce.py | 6 +- vortex_torch/cache/triton_kernels/__init__.py | 3 +- .../cache/triton_kernels/reduce_impl.py | 328 +++++++++--------- vortex_torch/cache/triton_kernels/set_kv.py | 97 +++++- 11 files changed, 319 insertions(+), 320 deletions(-) delete mode 100644 CLAUDE.md create mode 100755 examples/verify_algo_fp8.sh rename examples/{verify_algo_quant.sh => verify_algo_int8.sh} (100%) diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 1593d61..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,141 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. - -## Build & Install - -```bash -# Install SGLang dependency (custom fork in third_party/) -cd third_party/sglang && bash install.sh && cd ../../ - -# Install Vortex (editable mode, compiles CUDA extensions for SM_89/SM_90) -pip install -e . -``` - -Requires Python >=3.10, torch>=2.7. CUDA extensions are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu). - -## Running Examples - -```bash -# Single algorithm verification against SGLang -python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention - -# Batch test multiple algorithms -bash examples/verify_algo.sh -``` - -## Building Documentation - -```bash -make -C docs html -``` - -Uses Sphinx with myst_parser and furo theme. 
Deployed via GitHub Actions on push to v1 branch. - -## Architecture - -### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) - -All sparse attention algorithms inherit from `vFlow` and implement three methods: - -- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. -- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. -- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. - -Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. - -### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) - -Operators (`vOp` subclasses) run in two modes: -- **Profile mode**: Pre-compute output shapes and allocate buffers -- **Execute mode**: Perform actual GPU computation - -Operators are split into two parallel hierarchies: -- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load -- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup - -Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. - -### Tensor Format (`vortex_torch/abs/tensor.py`) - -`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. - -### Context System (`vortex_torch/abs/context_base.py`) - -`ContextBase` carries per-step runtime state. 
Specialized as: -- `Indexer.Context`: Page layout, head config, hardware info -- `Cache.Context`: Page size, total pages, model info - -### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) - -- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) -- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation -- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds - -### SGLang Integration - -Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, and transpose operations. - -## Key Conventions - -- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` -- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) -- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` -- **Branch**: Main development is on `v1` - -## Workflow Orchestration - -### 1. Plan Node Default -- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) -- If something goes sideways, STOP and re-plan immediately - don't keep pushing -- Use plan mode for verification steps, not just building -- Write detailed specs upfront to reduce ambiguity - -### 2. Subagent Strategy -- Use subagents liberally to keep main context window clean -- Offload research, exploration, and parallel analysis to subagents -- For complex problems, throw more compute at it via subagents -- One tack per subagent for focused execution - -### 3. 
Self-Improvement Loop -- After ANY correction from the user: update `tasks/lessons.md` with the pattern -- Write rules for yourself that prevent the same mistake -- Ruthlessly iterate on these lessons until mistake rate drops -- Review lessons at session start for relevant project - -### 4. Verification Before Done -- Never mark a task complete without proving it works -- Diff behavior between main and your changes when relevant -- Ask yourself: "Would a staff engineer approve this?" -- Run tests, check logs, demonstrate correctness - -### 5. Demand Elegance (Balanced) -- For non-trivial changes: pause and ask "is there a more elegant way?" -- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" -- Skip this for simple, obvious fixes - don't over-engineer -- Challenge your own work before presenting it - -### 6. Autonomous Bug Fixing -- When given a bug report: just fix it. Don't ask for hand-holding -- Point at logs, errors, failing tests - then resolve them -- Zero context switching required from the user -- Go fix failing CI tests without being told how - -## Task Management - -1. **Plan First**: Write plan to `tasks/todo.md` with checkable items -2. **Verify Plan**: Check in before starting implementation -3. **Track Progress**: Mark items complete as you go -4. **Explain Changes**: High-level summary at each step -5. **Document Results**: Add review section to `tasks/todo.md` -6. **Capture Lessons**: Update `tasks/lessons.md` after corrections - -## Core Principles - -- **Simplicity First**: Make every change as simple as possible. Impact minimal code. -- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. -- **Minimat Impact**: Changes should only touch what's necessary. Avoid introducing bugs. 
\ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index f418598..9958b7e 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -113,13 +113,13 @@ def verify_algos( } ) # --- Per-question debug output --- - print(f"[Q{len(results):03d}] score={float(result):.1f} " - f"tokens={item['meta_info']['completion_tokens']} " - f"latency={item['meta_info']['e2e_latency']:.2f}s " - f"gold={golds[0]}") - print(f" question: {data['question'][:120]}...") - print(f" prediction: {predictions[:200]}...") - print() + # print(f"[Q{len(results):03d}] score={float(result):.1f} " + # f"tokens={item['meta_info']['completion_tokens']} " + # f"latency={item['meta_info']['e2e_latency']:.2f}s " + # f"gold={golds[0]}") + # print(f" question: {data['question'][:120]}...") + # print(f" prediction: {predictions[:200]}...") + # print() total_accuracy = 0.0 diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh new file mode 100755 index 0000000..7f266e5 --- /dev/null +++ b/examples/verify_algo_fp8.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=3 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype fp8_e4m3 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_int8.sh similarity index 100% rename from examples/verify_algo_quant.sh rename to examples/verify_algo_int8.sh diff --git a/third_party/sglang b/third_party/sglang index 9672e9a..7105719 160000 --- 
a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 9672e9a7f90bcb782ccdfb2ee123ede7f2ef5d17 +Subproject commit 7105719f0a2ac464ee7ffdc0a899fa6a656656a2 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index b32d8bc..8c4d0e0 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,12 +29,13 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, dequant_paged_int8_to_bf16_inplace +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index ae2dd5c..dd1bd02 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -10,17 +10,21 @@ class Context(ContextBase): """ __slots__ = ContextBase.__slots__ + ( - + #page infomation "max_new_tokens_per_batch", "page_size", "total_num_pages", - + #model infomation "head_dim", "head_num", - + # auxilary memory in graph "_aux_total_bytes", - - "_aux_total_flops" + + "_aux_total_flops", + + # FP8 quantization: fp8_type (0=none, 1=e4m3, 2=e5m2), kv_scale (per-tensor) + "fp8_type", + "kv_scale", ) @@ -36,7 +40,11 @@ def __init__(self) -> None: elif name == "_aux_total_flops": object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": - object.__setattr__(self, name, Mode.profile) + object.__setattr__(self, name, Mode.profile) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = no fp8 (bf16 default) + elif name == "kv_scale": + object.__setattr__(self, name, 1.0) # identity scale for 
bf16 else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 3c4edf2..5800458 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,8 +345,10 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, fp8_type, scale) + fp8_type = getattr(ctx, 'fp8_type', 0) + scale = getattr(ctx, 'kv_scale', 1.0) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, fp8_type, scale) return output diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index a18067e..009e728 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,10 +1,11 @@ -from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher from .paged_decode_int8 import paged_decode_int8 from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", "paged_decode_int8", "dequant_paged_int8_to_bf16", "dequant_paged_int8_to_bf16_inplace", diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9921e08..9670acd 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -4,6 +4,16 @@ from ..context import Context from ...utils import ReduceType + +# --------------------------------------------------------------------------- +# Helper: Load a page block from src_ptr, handling bf16 or fp8-stored-as-uint8. 
+# FP8_TYPE == 0 -> bf16 pointer, load normally +# FP8_TYPE == 1 -> uint8 pointer, bitcast to float8e4nv, dequant with scale +# FP8_TYPE == 2 -> uint8 pointer, bitcast to float8e5, dequant with scale +# All paths return a float32 tensor ready for reduction. +# --------------------------------------------------------------------------- + + @triton.jit def reduce_pp_kernel( x, output, loc, @@ -12,9 +22,11 @@ def reduce_pp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm -DIM: tl.constexpr # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -): - +DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 +FP8_TYPE: tl.constexpr, # 0: bf16, 1: e4m3, 2: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +): + token_id = tl.program_id(0) head_id = tl.program_id(1) @@ -29,7 +41,15 @@ def reduce_pp_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - page_block = tl.load(src_ptr) + + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -40,7 +60,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) @@ -55,7 +75,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) 
+ s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) @@ -71,11 +91,13 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -85,7 +107,9 @@ def reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -97,11 +121,13 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -111,7 +137,9 @@ def _reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -119,84 +147,67 @@ def _reduce_pp( @triton.jit def reduce_rp_kernel( x, output, loc, - x_D0: tl.constexpr, # rows per token-page - x_D1: tl.constexpr, # cols per token-page + x_D0: tl.constexpr, + x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, - REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) - DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + REDUCE_TYPE: tl.constexpr, + DIM: tl.constexpr, + FP8_TYPE: tl.constexpr, + scale, ): - - # Program IDs: - # pid0 = token index (0 .. num_tokens-1) - # pid1 = head index (0 .. NUM_KV_HEAD-1) + token_id = tl.program_id(0) head_id = tl.program_id(1) - # Load the absolute position of this token (used to map to page index). token_position = tl.load(loc + token_id) - # Only the last token of a page triggers the reduction. 
if (token_position + 1) % PAGE_SIZE != 0: return - # Output page index: - # Logical page = token_position // PAGE_SIZE - # One vector per head, so linearize by NUM_KV_HEAD. page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id - - # Input layout is [num_tokens, num_heads, x_D0, x_D1] (row-major). - # For this token/head, compute the base element offset in `x`. x_offset = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - # Build 2D indices within a page (row-major addressing). - rows = tl.arange(0, x_D0)[:, None] # shape [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # shape [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_offset + rows * x_D1 + cols - # Load the full page block for this (token_id, head_id). - # Assumes the page is full; add masks here if you have partial tiles. - page_block = tl.load(src_ptr) + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) - # Reduction: if DIM == 1: - # Reduce over rows (axis=0) -> output vector length x_D1 (per-column reduce). 
- if REDUCE_TYPE == 0: # Mean - # NOTE: precision-sensitive workloads may want fp32 accumulation: - # s = tl.sum(page_block.to(tl.float32), axis=0) - # reduce_vec = (s / x_D0).to(tl.bfloat16) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: # L2Norm (sqrt(sum(x*x))); NOT RMS - # For RMS, use: tl.sqrt(tl.sum(page_block*page_block, axis=0) / x_D0) - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # Write to output: layout [num_pages, x_D1] for DIM==1. dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - # DIM == 2: Reduce over cols (axis=1) -> output vector length x_D0 (per-row reduce). - if REDUCE_TYPE == 0: # Mean - # s = tl.sum(page_block.to(tl.float32), axis=1) - # reduce_vec = (s / x_D1).to(tl.bfloat16) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: # L2Norm (sqrt(sum(x*x))); NOT RMS - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # Write to output: layout [num_pages, x_D0] for DIM==2. 
dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) - def reduce_rp( @@ -206,11 +217,13 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -220,7 +233,9 @@ def reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -232,11 +247,13 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -246,92 +263,76 @@ def _reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @triton.jit def reduce_pr_kernel( x, output, loc, -x_D0: tl.constexpr, # rows per page -x_D1: tl.constexpr, # cols per page +x_D0: tl.constexpr, +x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +REDUCE_TYPE: tl.constexpr, +DIM: tl.constexpr, +FP8_TYPE: tl.constexpr, +scale, ): - """ - Layouts: - x: [num_pages * NUM_KV_HEAD, x_D0, x_D1] (page-major, row-major inside page) - output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) - - Behavior: - - token_id comes from pid0; head_id comes from pid1. - - Read loc[token_id] to get absolute position; only proceed at page end. - - Map token -> page via page_idx = (token_position // PAGE_SIZE). - - Read the whole page for this (page_idx, head_id), do reduction, - then write a single vector to output at (token_id, head_id, :). 
- """ - - # --- Program IDs --- - token_id = tl.program_id(0) # [0 .. num_tokens-1] - head_id = tl.program_id(1) # [0 .. NUM_KV_HEAD-1] - - # --- Trigger only at end-of-page token --- + + token_id = tl.program_id(0) + head_id = tl.program_id(1) + token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return - # --- Page indexing for x (page-major) --- - # page linear id across heads page_idx = token_position // PAGE_SIZE page_id = page_idx * NUM_KV_HEAD + head_id - # Base element offset into x for this (page_id, head_id) - # x is laid out as contiguous pages, each page is [x_D0, x_D1] x_offset = page_id * x_D0 * x_D1 - # 2D row-major addressing within the page - rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_offset + rows * x_D1 + cols - # Load the full page block. Assumes full tiles; add masks if needed. - page_block = tl.load(src_ptr) + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) - # --- Reduction & write-out --- if DIM == 1: - # Reduce over rows (axis=0) -> per-column vector, length = x_D1 - if REDUCE_TYPE == 0: # Mean - # For better accuracy you may upcast: tl.sum(page_block.to(tl.float32), axis=0) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: # L2Norm (NOT RMS) - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=0) reduce_vec 
= tl.sqrt(s).to(tl.bfloat16) - # output is token-major: [num_tokens, NUM_KV_HEAD, x_D1] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 dst_ptr = output + out_base + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - # DIM == 2: Reduce over cols (axis=1) -> per-row vector, length = x_D0 - if REDUCE_TYPE == 0: # Mean + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: # L2Norm (NOT RMS) - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - - # output is token-major: [num_tokens, NUM_KV_HEAD, x_D0] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 dst_ptr = output + out_base + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) @@ -344,11 +345,13 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -358,9 +361,11 @@ def reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) - + def _reduce_pr( x: torch.Tensor, output: torch.Tensor, @@ -369,11 +374,13 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -383,72 +390,68 @@ def _reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @triton.jit def reduce_rr_kernel( x, output, loc, -x_D0: tl.constexpr, # rows per 
token-page -x_D1: tl.constexpr, # cols per token-page +x_D0: tl.constexpr, +x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +REDUCE_TYPE: tl.constexpr, +DIM: tl.constexpr, +FP8_TYPE: tl.constexpr, +scale, ): - """ - Layouts: - x: [num_tokens * NUM_KV_HEAD, x_D0, x_D1] (token-major) - output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) - Only the last token of each page performs the reduction and writes to output[token_id, head_id, :]. - """ - - - # program ids - token_id = tl.program_id(0) # 0..num_tokens-1 - head_id = tl.program_id(1) # 0..NUM_KV_HEAD-1 + token_id = tl.program_id(0) + head_id = tl.program_id(1) - # trigger only at end-of-page token token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return - # ---- read from x (token-major) ---- x_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_base + rows * x_D1 + cols - page_blk = tl.load(src_ptr) # assumes full page; add masks if needed - # ---- reduce ---- + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_blk = tl.load(src_ptr).to(tl.float32) + if DIM == 1: - # over rows -> axis=0 -> vector len x_D1 - if REDUCE_TYPE == 0: # Mean - # For better accuracy you may upcast to fp32 before sum. 
+ if REDUCE_TYPE == 0: vec = (tl.sum(page_blk, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: vec = tl.max(page_blk, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: vec = tl.min(page_blk, axis=0).to(tl.bfloat16) - else: # L2Norm (NOT RMS) + else: s = tl.sum(page_blk * page_blk, axis=0) vec = tl.sqrt(s).to(tl.bfloat16) - # ---- write to output (token-major) ---- out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 tl.store(output + out_base + tl.arange(0, x_D1), vec) else: - # DIM == 2: over cols -> axis=1 -> vector len x_D0 - if REDUCE_TYPE == 0: # Mean + if REDUCE_TYPE == 0: vec = (tl.sum(page_blk, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: vec = tl.max(page_blk, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: vec = tl.min(page_blk, axis=1).to(tl.bfloat16) - else: # L2Norm (NOT RMS) + else: s = tl.sum(page_blk * page_blk, axis=1) vec = tl.sqrt(s).to(tl.bfloat16) @@ -456,7 +459,6 @@ def reduce_rr_kernel( tl.store(output + out_base + tl.arange(0, x_D0), vec) - def reduce_rr( x: torch.Tensor, output: torch.Tensor, @@ -464,11 +466,13 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -478,9 +482,11 @@ def reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) - + def _reduce_rr( x: torch.Tensor, @@ -490,11 +496,13 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -504,5 +512,7 @@ def _reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - 
DIM=dim - ) \ No newline at end of file + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 2a2c785..6b289df 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -131,11 +131,11 @@ def set_kv_buffer_launcher( loc: torch.LongTensor, page_size: int ): - + NNZ = loc.shape[0] NUM_KV_HEAD = new_k.shape[1] HEAD_DIM = new_k.shape[2] - + set_kv_buffer_kernel[(NNZ, NUM_KV_HEAD)]( k_cache, v_cache, @@ -148,3 +148,96 @@ def set_kv_buffer_launcher( page_size ) + +@triton.jit +def set_kv_buffer_fp8_kernel( + k_cache, # uint8 paged K cache + v_cache, # uint8 paged V cache + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + FP8_TYPE: tl.constexpr, # 1: e4m3 (max=448), 2: e5m2 (max=57344) + k_scale, # float: per-tensor scale for K quantization + v_scale, # float: per-tensor scale for V quantization +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Scale down: quantized = real_value / scale + inv_k_scale = 1.0 / k_scale + inv_v_scale = 1.0 / v_scale + scaled_k = src_k * inv_k_scale + scaled_v = src_v * inv_v_scale + + # Clamp and cast to fp8, then bitcast to uint8 for storage + if FP8_TYPE == 1: + # e4m3: max = 448.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -448.0), 448.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -448.0), 448.0) + q_k = 
clamped_k.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + else: + # e5m2: max = 57344.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -57344.0), 57344.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -57344.0), 57344.0) + q_k = clamped_k.to(tl.float8e5).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e5).to(tl.uint8, bitcast=True) + + # Compute paged destination offset + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write uint8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + +def set_kv_buffer_fp8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int, + k_scale: float, + v_scale: float, + fp8_type: int = 1, +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache. + + Args: + fp8_type: 1 for e4m3 (default), 2 for e5m2. + k_scale: per-tensor scale used for K quantization. + v_scale: per-tensor scale used for V quantization. 
+ """ + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_fp8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, v_cache, + new_k, new_v, + loc, + NUM_KV_HEAD, NNZ, HEAD_DIM, page_size, + FP8_TYPE=fp8_type, + k_scale=k_scale, + v_scale=v_scale, + ) + From f25fb13e6075111a9b05aedb060a3e7bd346cebd Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 1 Mar 2026 08:02:49 +0000 Subject: [PATCH 4/7] update on parameters for reduce_pp_kernel with quantization --- setup.py | 2 +- vortex_torch/cache/context.py | 12 +- vortex_torch/cache/reduce.py | 7 +- .../cache/triton_kernels/reduce_impl.py | 305 ++++++++++++------ 4 files changed, 224 insertions(+), 102 deletions(-) diff --git a/setup.py b/setup.py index 6efeebe..f35ddae 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ sources=[ 'csrc/register.cc', 'csrc/utils_sglang.cu', - 'csrc/topk.cu' + 'csrc/topk.cu', ], include_dirs=['csrc'], extra_compile_args={ diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index dd1bd02..0e7171c 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -22,9 +22,11 @@ class Context(ContextBase): "_aux_total_flops", - # FP8 quantization: fp8_type (0=none, 1=e4m3, 2=e5m2), kv_scale (per-tensor) - "fp8_type", + # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), + # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + "quant_type", "kv_scale", + "kv_scale_ptr", ) @@ -41,10 +43,12 @@ def __init__(self) -> None: object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": object.__setattr__(self, name, Mode.profile) - elif name == "fp8_type": - object.__setattr__(self, name, 0) # 0 = no fp8 (bf16 default) + elif name == "quant_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) elif name == "kv_scale": object.__setattr__(self, name, 1.0) # identity scale for bf16 + elif name == "kv_scale_ptr": + object.__setattr__(self, name, None) # per-token scale 
tensor (int8 only) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 5800458..eb94795 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,10 +345,11 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, fp8_type, scale) - fp8_type = getattr(ctx, 'fp8_type', 0) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr) + quant_type = getattr(ctx, 'quant_type', 0) scale = getattr(ctx, 'kv_scale', 1.0) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type, fp8_type, scale) + kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr) return output diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9670acd..0146af7 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -6,11 +6,12 @@ # --------------------------------------------------------------------------- -# Helper: Load a page block from src_ptr, handling bf16 or fp8-stored-as-uint8. -# FP8_TYPE == 0 -> bf16 pointer, load normally -# FP8_TYPE == 1 -> uint8 pointer, bitcast to float8e4nv, dequant with scale -# FP8_TYPE == 2 -> uint8 pointer, bitcast to float8e5, dequant with scale -# All paths return a float32 tensor ready for reduction. +# Helper: Load a page block from src_ptr, handling bf16 / int8 / fp8-stored-as-uint8. 
+# QUANT_TYPE == 0 -> bf16 pointer, load normally +# QUANT_TYPE == 1 -> int8 pointer, dequant with per-row scale from kv_scale_ptr +# QUANT_TYPE == 2 -> uint8 pointer, bitcast to float8e4nv, dequant with per-tensor scale +# QUANT_TYPE == 3 -> uint8 pointer, bitcast to float8e5, dequant with per-tensor scale +# All quantised paths return a float32 tensor ready for reduction. # --------------------------------------------------------------------------- @@ -23,8 +24,9 @@ def reduce_pp_kernel( PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -FP8_TYPE: tl.constexpr, # 0: bf16, 1: e4m3, 2: e5m2 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): token_id = tl.program_id(0) @@ -42,14 +44,21 @@ def reduce_pp_kernel( cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + # Per-row scales stored at kv_scale_ptr[page_id * x_D0 + row] + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) # [x_D0] + page_block = raw * row_scales[:, None] # broadcast [x_D0, 1] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -60,7 +69,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) 
else: # L2Norm - s = tl.sum(page_block * page_block, axis=0) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) @@ -75,7 +84,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=1) + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) @@ -91,8 +100,9 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -108,8 +118,9 @@ def reduce_pp( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -121,8 +132,9 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -138,8 +150,9 @@ def _reduce_pp( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -147,69 +160,102 @@ def _reduce_pp( @triton.jit def reduce_rp_kernel( x, output, loc, - x_D0: tl.constexpr, - x_D1: tl.constexpr, + x_D0: tl.constexpr, # rows per token-page + x_D1: tl.constexpr, # cols per token-page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, - REDUCE_TYPE: tl.constexpr, - DIM: tl.constexpr, - FP8_TYPE: tl.constexpr, - scale, + REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) + DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + scale, # float: 1.0 for 
bf16, kv_scale for fp8 + kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): + # Program IDs: + # pid0 = token index (0 .. num_tokens-1) + # pid1 = head index (0 .. NUM_KV_HEAD-1) token_id = tl.program_id(0) head_id = tl.program_id(1) + # Load the absolute position of this token (used to map to page index). token_position = tl.load(loc + token_id) + # Only the last token of a page triggers the reduction. if (token_position + 1) % PAGE_SIZE != 0: return + # Output page index: + # Logical page = token_position // PAGE_SIZE + # One vector per head, so linearize by NUM_KV_HEAD. page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + + # Input layout is [num_tokens, num_heads, x_D0, x_D1] (row-major). + # For this token/head, compute the base element offset in `x`. x_offset = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + # Build 2D indices within a page (row-major addressing). + rows = tl.arange(0, x_D0)[:, None] # shape [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # shape [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + # Load the full page block for this (token_id, head_id). + # Assumes the page is full; add masks here if you have partial tiles. 
+ if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) + # Reduction: if DIM == 1: - if REDUCE_TYPE == 0: + # Reduce over rows (axis=0) -> output vector length x_D1 (per-column reduce). + if REDUCE_TYPE == 0: # Mean + # NOTE: precision-sensitive workloads may want fp32 accumulation: + # s = tl.sum(page_block.to(tl.float32), axis=0) + # reduce_vec = (s / x_D0).to(tl.bfloat16) reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=0) + else: # L2Norm (sqrt(sum(x*x))); NOT RMS + # For RMS, use: tl.sqrt(tl.sum(page_block*page_block, axis=0) / x_D0) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # Write to output: layout [num_pages, x_D1] for DIM==1. dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: Reduce over cols (axis=1) -> output vector length x_D0 (per-row reduce). 
+ if REDUCE_TYPE == 0: # Mean + # s = tl.sum(page_block.to(tl.float32), axis=1) + # reduce_vec = (s / x_D1).to(tl.bfloat16) reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=1) + else: # L2Norm (sqrt(sum(x*x))); NOT RMS + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # Write to output: layout [num_pages, x_D0] for DIM==2. dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) + def reduce_rp( x: torch.Tensor, output: torch.Tensor, @@ -217,8 +263,9 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -234,8 +281,9 @@ def reduce_rp( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -247,8 +295,9 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -264,75 +313,110 @@ def _reduce_rp( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @triton.jit def reduce_pr_kernel( x, output, loc, -x_D0: tl.constexpr, -x_D1: tl.constexpr, +x_D0: tl.constexpr, # rows per page +x_D1: tl.constexpr, # cols per page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, -DIM: tl.constexpr, -FP8_TYPE: tl.constexpr, -scale, +REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: 
L2Norm (not RMS) +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): - - token_id = tl.program_id(0) - head_id = tl.program_id(1) - + """ + Layouts: + x: [num_pages * NUM_KV_HEAD, x_D0, x_D1] (page-major, row-major inside page) + output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) + + Behavior: + - token_id comes from pid0; head_id comes from pid1. + - Read loc[token_id] to get absolute position; only proceed at page end. + - Map token -> page via page_idx = (token_position // PAGE_SIZE). + - Read the whole page for this (page_idx, head_id), do reduction, + then write a single vector to output at (token_id, head_id, :). + """ + + # --- Program IDs --- + token_id = tl.program_id(0) # [0 .. num_tokens-1] + head_id = tl.program_id(1) # [0 .. NUM_KV_HEAD-1] + + # --- Trigger only at end-of-page token --- token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return + # --- Page indexing for x (page-major) --- + # page linear id across heads page_idx = token_position // PAGE_SIZE page_id = page_idx * NUM_KV_HEAD + head_id + # Base element offset into x for this (page_id, head_id) + # x is laid out as contiguous pages, each page is [x_D0, x_D1] x_offset = page_id * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + # 2D row-major addressing within the page + rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + # Load the full page block. Assumes full tiles; add masks if needed. 
+ if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) + # --- Reduction & write-out --- if DIM == 1: - if REDUCE_TYPE == 0: + # Reduce over rows (axis=0) -> per-column vector, length = x_D1 + if REDUCE_TYPE == 0: # Mean + # For better accuracy you may upcast: tl.sum(page_block.to(tl.float32), axis=0) reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=0) + else: # L2Norm (NOT RMS) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # output is token-major: [num_tokens, NUM_KV_HEAD, x_D1] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 dst_ptr = output + out_base + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: Reduce over cols (axis=1) -> per-row vector, length = x_D0 + if REDUCE_TYPE == 0: # Mean reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, 
axis=1) + else: # L2Norm (NOT RMS) + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + + # output is token-major: [num_tokens, NUM_KV_HEAD, x_D0] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 dst_ptr = output + out_base + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) @@ -345,8 +429,9 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -362,8 +447,9 @@ def reduce_pr( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) def _reduce_pr( @@ -374,8 +460,9 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -391,67 +478,92 @@ def _reduce_pr( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @triton.jit def reduce_rr_kernel( x, output, loc, -x_D0: tl.constexpr, -x_D1: tl.constexpr, +x_D0: tl.constexpr, # rows per token-page +x_D1: tl.constexpr, # cols per token-page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, -DIM: tl.constexpr, -FP8_TYPE: tl.constexpr, -scale, +REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): + """ + Layouts: + x: [num_tokens * NUM_KV_HEAD, x_D0, x_D1] (token-major) + output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if 
DIM==1 else x_D0) + + Only the last token of each page performs the reduction and writes to output[token_id, head_id, :]. + """ - token_id = tl.program_id(0) - head_id = tl.program_id(1) + # program ids + token_id = tl.program_id(0) # 0..num_tokens-1 + head_id = tl.program_id(1) # 0..NUM_KV_HEAD-1 + + # trigger only at end-of-page token token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return + # ---- read from x (token-major) ---- x_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_base + rows * x_D1 + cols - if FP8_TYPE == 1: + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_blk = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_blk = tl.load(src_ptr).to(tl.float32) + page_blk = tl.load(src_ptr) # assumes full page; add masks if needed + # ---- reduce ---- if DIM == 1: - if REDUCE_TYPE == 0: + # over rows -> axis=0 -> vector len x_D1 + if REDUCE_TYPE == 0: # Mean + # For better accuracy you may upcast to fp32 before sum. 
vec = (tl.sum(page_blk, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max vec = tl.max(page_blk, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min vec = tl.min(page_blk, axis=0).to(tl.bfloat16) - else: + else: # L2Norm (NOT RMS) s = tl.sum(page_blk * page_blk, axis=0) vec = tl.sqrt(s).to(tl.bfloat16) + # ---- write to output (token-major) ---- out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 tl.store(output + out_base + tl.arange(0, x_D1), vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: over cols -> axis=1 -> vector len x_D0 + if REDUCE_TYPE == 0: # Mean vec = (tl.sum(page_blk, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max vec = tl.max(page_blk, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min vec = tl.min(page_blk, axis=1).to(tl.bfloat16) - else: + else: # L2Norm (NOT RMS) s = tl.sum(page_blk * page_blk, axis=1) vec = tl.sqrt(s).to(tl.bfloat16) @@ -459,6 +571,7 @@ def reduce_rr_kernel( tl.store(output + out_base + tl.arange(0, x_D0), vec) + def reduce_rr( x: torch.Tensor, output: torch.Tensor, @@ -466,8 +579,9 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -483,8 +597,9 @@ def reduce_rr( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -496,8 +611,9 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -513,6 +629,7 @@ def _reduce_rr( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) From 
b9eb71786eb8dd12150f02b4dcd15b053328b931 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 2 Mar 2026 06:33:57 +0000 Subject: [PATCH 5/7] adapt topk kernel from sglang to vortex --- csrc/topk.cu | 1029 +++++++++++++++++++++++++++------- examples/verify_algo.sh | 2 +- examples/verify_algo_fp8.sh | 1 - examples/verify_algo_int8.sh | 1 - 4 files changed, 827 insertions(+), 206 deletions(-) diff --git a/csrc/topk.cu b/csrc/topk.cu index 62d747e..8a48aad 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,203 +1,826 @@ -#include "register.h" -#include - - -template -__global__ void TopKOutput_F32_Kernel( -const float* __restrict__ score, -const int* __restrict__ dense_kv_indptr, -const int* __restrict__ sparse_kv_indptr, -const int* __restrict__ dense_kv_indices, -int* __restrict__ sparse_kv_indices, -const int topk_val, -const int page_reserved_bos, -const int page_reserved_eos) -{ - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const float* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; - - float key[ITEM_PER_THREAD]; - int val[ITEM_PER_THREAD]; - - using BLF = cub::BlockLoad; - using BLI = cub::BlockLoad; - using BSI = cub::BlockStore; - using Sort = cub::BlockRadixSort; - - __shared__ union { - typename BLF::TempStorage lf; - typename BLI::TempStorage li; - typename BSI::TempStorage si; - typename Sort::TempStorage sort; - } temp; - - BLF(temp.lf).Load(score_blk, key, nblk, -INFINITY); - __syncthreads(); - BLI(temp.li).Load(idx_blk, val, nblk, 0); - __syncthreads(); - - Sort(temp.sort).SortDescending(key, val); - __syncthreads(); - - const int valid_out = min(topk_val, nblk); - BSI(temp.si).Store(out_blk, /*per-thread regs*/ 
val, valid_out); -} - - -template -__global__ void TopKOutput_BF16_Kernel( -const __nv_bfloat16* __restrict__ score, -const int* __restrict__ dense_kv_indptr, -const int* __restrict__ sparse_kv_indptr, -const int* __restrict__ dense_kv_indices, -int* __restrict__ sparse_kv_indices, -const int topk_val, -const int page_reserved_bos, -const int page_reserved_eos) -{ - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const __nv_bfloat16* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; - - const __nv_bfloat16 ninf_bf16 = __float2bfloat16(-CUDART_INF_F); - - __nv_bfloat16 key_bf16[ITEM_PER_THREAD]; - float key[ITEM_PER_THREAD]; - int val[ITEM_PER_THREAD]; - - using BLF = cub::BlockLoad<__nv_bfloat16, NUM_THREADS, ITEM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BLI = cub::BlockLoad; - using BSI = cub::BlockStore; - using Sort = cub::BlockRadixSort; - - __shared__ union { - typename BLF::TempStorage lf; - typename BLI::TempStorage li; - typename BSI::TempStorage si; - typename Sort::TempStorage sort; - } temp; - - BLF(temp.lf).Load(score_blk, key_bf16, nblk, ninf_bf16); - - #pragma unroll - for (int i = 0; i < ITEM_PER_THREAD; ++i){ - key[i] = __bfloat162float(key_bf16[i]); - } - __syncthreads(); - - BLI(temp.li).Load(idx_blk, val, nblk, 0); - __syncthreads(); - - Sort(temp.sort).SortDescending(key, val); - __syncthreads(); - - const int valid_out = min(topk_val, nblk); - BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); -} - - - -void topk_output( -const at::Tensor& x, -const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, -const at::Tensor& dense_kv_indices, -at::Tensor& 
sparse_kv_indices, -const int64_t eff_batch_size, -const int64_t topk_val, -const int64_t reserved_bos, -const int64_t reserved_eos, -const int64_t max_num_pages -){ - - - dim3 nblks(eff_batch_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (max_num_pages <= 128){ - TopKOutput_BF16_Kernel<128, 1><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 256){ - TopKOutput_BF16_Kernel<128, 2><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 512){ - TopKOutput_BF16_Kernel<128, 4><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 1024){ - TopKOutput_BF16_Kernel<256, 4><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 2048){ - TopKOutput_BF16_Kernel<256, 8><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 4096){ - TopKOutput_BF16_Kernel<512, 8><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - 
topk_val, - reserved_bos, - reserved_eos - ); - } else { - TORCH_CHECK(false); - } - -} +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + constexpr int TopK = 2048; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? 
i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } 
else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = 
threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex integration: BOS/EOS-aware segmented TopK with index remapping + // ====================================================================== + + template + __device__ __forceinline__ float vortex_to_float(T x); + + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + constexpr int VORTEX_MAX_TOPK = 2048; + + // Templated version of fast_topk_cuda_tl: + // - ScoreT: float or __nv_bfloat16 + // - target_k: runtime parameter (replaces compile-time TopK) + template + __device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) + { + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int 
vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + // Wrapper kernel: one CUDA block per batch*head segment + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) + { + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = 
at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + 
params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex host entry point — same interface as topk_output in topk.cu + // ====================================================================== + void topk_output( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& 
sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); + } \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index d80f09a..7487708 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=1 +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh index 7f266e5..fd85dad 100755 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_fp8.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=3 sparse_algos=( 
"block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh index 4cf1366..e57c63f 100644 --- a/examples/verify_algo_int8.sh +++ b/examples/verify_algo_int8.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=2 sparse_algos=( "block_sparse_attention" From ede862425998eaa1e5a3b449dec274798d0ffb1c Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 9 Mar 2026 05:22:43 +0000 Subject: [PATCH 6/7] add parameter to switch between two topk kernels (naive or sglang) --- csrc/register.cc | 1 + csrc/register.h | 12 + csrc/topk.cu | 1029 ++++++--------------------- examples/verify_algo.py | 15 +- examples/verify_algo_fp8.sh | 1 + examples/verify_algo_int8.sh | 1 + vortex_torch/indexer/context.py | 4 +- vortex_torch/indexer/output_func.py | 58 +- 8 files changed, 276 insertions(+), 845 deletions(-) diff --git a/csrc/register.cc b/csrc/register.cc index fd9d4eb..532fcdf 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,6 +8,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); m.def("topk_output", &topk_output); + m.def("topk_output_sglang", &topk_output_sglang); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 92499ed..b81168b 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -85,6 +85,18 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_seq_lengths +); void 
sglang_plan_decode_fa3( const at::Tensor& cached_seq_lens, diff --git a/csrc/topk.cu b/csrc/topk.cu index 8a48aad..3aa49b9 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,826 +1,203 @@ -/** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. fix the potential illegal memory access - */ - #include - #include - #include - #include - #include - #include - #include - #include - #include - - #include - #include - #include - - namespace { - - constexpr int TopK = 2048; - constexpr int kThreadsPerBlock = 1024; - - #ifdef USE_ROCM - // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a - // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. - #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES - constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); - #else - constexpr size_t kSmem = 48 * 1024; // bytes - #endif - #else - // Reduced from 128KB to 32KB to improve occupancy. - // Each radix pass needs at most ~TopK candidates in the threshold bin, - // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. - constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) - #endif - - struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; - }; - - // when length <= TopK, we can directly write the indices - __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? 
i : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? src_page_table[i] : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } - } - - __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); - } - - __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); - } - - __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } 
else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = 
threadIdx.x; - const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } - } - - auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; - } - - template - void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { - #ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex integration: BOS/EOS-aware segmented TopK with index remapping - // ====================================================================== - - template - __device__ __forceinline__ float vortex_to_float(T x); - - template <> - __device__ __forceinline__ float vortex_to_float(float x) { return x; } - - template <> - __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); - } - - constexpr int VORTEX_MAX_TOPK = 2048; - - // Templated version of fast_topk_cuda_tl: - // - ScoreT: float or __nv_bfloat16 - // - target_k: runtime parameter (replaces compile-time TopK) - template - __device__ void fast_topk_vortex( - const ScoreT* __restrict__ input, - int* __restrict__ index, - int row_start, - int length, - int target_k) - { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int 
vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; - - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // Stage 1: 8-bit coarse histogram - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&vh_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - convert_to_uint8(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - // Wrapper kernel: one CUDA block per batch*head segment - template - __global__ __launch_bounds__(kThreadsPerBlock) - void TopKOutput_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos) - { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } - } - - } // namespace - - #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - - void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = 
at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); - CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - 
params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex host entry point — same interface as topk_output in topk.cu - // ====================================================================== - void topk_output( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& 
sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages) - { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); - } \ No newline at end of file +#include "register.h" +#include + + +template +__global__ void TopKOutput_F32_Kernel( +const float* __restrict__ score, +const int* __restrict__ dense_kv_indptr, +const int* __restrict__ sparse_kv_indptr, +const int* __restrict__ dense_kv_indices, +int* __restrict__ sparse_kv_indices, +const int topk_val, +const int page_reserved_bos, +const int page_reserved_eos) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + 
const int nblk = end - start; + if (nblk <= topk_val) return; + + const float* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; + + float key[ITEM_PER_THREAD]; + int val[ITEM_PER_THREAD]; + + using BLF = cub::BlockLoad; + using BLI = cub::BlockLoad; + using BSI = cub::BlockStore; + using Sort = cub::BlockRadixSort; + + __shared__ union { + typename BLF::TempStorage lf; + typename BLI::TempStorage li; + typename BSI::TempStorage si; + typename Sort::TempStorage sort; + } temp; + + BLF(temp.lf).Load(score_blk, key, nblk, -INFINITY); + __syncthreads(); + BLI(temp.li).Load(idx_blk, val, nblk, 0); + __syncthreads(); + + Sort(temp.sort).SortDescending(key, val); + __syncthreads(); + + const int valid_out = min(topk_val, nblk); + BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); +} + + +template +__global__ void TopKOutput_BF16_Kernel( +const __nv_bfloat16* __restrict__ score, +const int* __restrict__ dense_kv_indptr, +const int* __restrict__ sparse_kv_indptr, +const int* __restrict__ dense_kv_indices, +int* __restrict__ sparse_kv_indices, +const int topk_val, +const int page_reserved_bos, +const int page_reserved_eos) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const __nv_bfloat16* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; + + const __nv_bfloat16 ninf_bf16 = __float2bfloat16(-CUDART_INF_F); + + __nv_bfloat16 key_bf16[ITEM_PER_THREAD]; + float key[ITEM_PER_THREAD]; + int val[ITEM_PER_THREAD]; + + using BLF = cub::BlockLoad<__nv_bfloat16, NUM_THREADS, 
ITEM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>; + using BLI = cub::BlockLoad; + using BSI = cub::BlockStore; + using Sort = cub::BlockRadixSort; + + __shared__ union { + typename BLF::TempStorage lf; + typename BLI::TempStorage li; + typename BSI::TempStorage si; + typename Sort::TempStorage sort; + } temp; + + BLF(temp.lf).Load(score_blk, key_bf16, nblk, ninf_bf16); + + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; ++i){ + key[i] = __bfloat162float(key_bf16[i]); + } + __syncthreads(); + + BLI(temp.li).Load(idx_blk, val, nblk, 0); + __syncthreads(); + + Sort(temp.sort).SortDescending(key, val); + __syncthreads(); + + const int valid_out = min(topk_val, nblk); + BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); +} + + + +void topk_output( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +){ + + + dim3 nblks(eff_batch_size); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + if (max_num_pages <= 128){ + TopKOutput_BF16_Kernel<128, 1><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 256){ + TopKOutput_BF16_Kernel<128, 2><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 512){ + TopKOutput_BF16_Kernel<128, 4><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + 
sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 1024){ + TopKOutput_BF16_Kernel<256, 4><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 2048){ + TopKOutput_BF16_Kernel<256, 8><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 4096){ + TopKOutput_BF16_Kernel<512, 8><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else { + TORCH_CHECK(false); + } + +} \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 9958b7e..1187aca 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -56,12 +56,13 @@ def verify_algos( sparse_attention: bool = True, mem: float = 0.8, kv_cache_dtype: str = "auto", +topk_type: str = "naive", ): - llm = sgl.Engine(model_path=model_name, + llm = sgl.Engine(model_path=model_name, disable_cuda_graph=False, page_size=page_size, - vortex_topk_val=topk_val, + vortex_topk_val=topk_val, disable_overlap_schedule=True, attention_backend="flashinfer", enable_vortex_sparsity=sparse_attention, @@ -72,6 +73,7 @@ def verify_algos( vortex_max_seq_lens=12288, mem_fraction_static=mem, kv_cache_dtype=kv_cache_dtype, + vortex_topk_type=topk_type, ) with open("amc23.jsonl", "r", encoding="utf-8") as f: @@ -221,6 +223,14 @@ def parse_args(): choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], help='KV cache dtype (default: "auto").', ) + + parser.add_argument( + 
"--topk-type", + type=str, + default="naive", + choices=["naive", "sglang"], + help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', + ) return parser.parse_args() if __name__ == "__main__": @@ -235,6 +245,7 @@ def parse_args(): sparse_attention=not(args.full_attention), mem=args.mem, kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, ) print(summary) diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh index fd85dad..c0b8814 100755 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_fp8.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh index e57c63f..bf24c2d 100644 --- a/examples/verify_algo_int8.sh +++ b/examples/verify_algo_int8.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 6d3c586..d6da9c1 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -22,7 +22,7 @@ class Context(ContextBase): # hardware / paging "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc - "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", + "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", # auxilary memory in graph "_aux_total_bytes", @@ -68,6 +68,7 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). + topk_type: str #: TopK kernel type: "naive" or "sglang". # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. 
@@ -144,6 +145,7 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos + self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 5df795b..f7d0d9c 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Callable, Optional from ..abs import vOp -from vortex_torch_C import topk_output +from vortex_torch_C import topk_output, topk_output_sglang from .context import Context from ..abs import vTensor, FORMAT @@ -75,13 +75,17 @@ class topK(vOp): """ # Dispatch by input format; only RAGGED is supported for now. - _impl_map: Dict[FORMAT, Callable] = { - FORMAT.RAGGED: topk_output, + _impl_map: Dict[FORMAT, Dict[str, Callable]] = { + FORMAT.RAGGED: { + "naive": topk_output, + "sglang": topk_output_sglang, + }, } def __init__(self): super().__init__() self.impl: Optional[Callable] = None + self.topk_type: str = "naive" # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -152,7 +156,13 @@ def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: f"{prefix}no implementation for x._format={x_fmt}. " f"Available: {list(self._impl_map.keys())}" ) - self.impl = self._impl_map[x_fmt] + self.topk_type = getattr(ctx, "topk_type", "naive") + impl_variants = self._impl_map[x_fmt] + assert self.topk_type in impl_variants, ( + f"{prefix}no topk implementation for topk_type='{self.topk_type}'. 
" + f"Available: {list(impl_variants.keys())}" + ) + self.impl = impl_variants[self.topk_type] # ---- optional sanity checks on `o` ---- # We only assert device consistency and leave exact (S_pack, D0, D1) @@ -220,16 +230,32 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso prefix = self._prefix() assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" - self.impl( - x, - ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, - ctx.dense_kv_indices, - o, - ctx.batch_size * ctx.num_kv_heads, - ctx.topk_val, - ctx.page_reserved_bos, - ctx.page_reserved_eos, - ctx.max_num_pages_per_request, - ) + if self.topk_type == "sglang": + # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + else: + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) 
+            self.impl(
+                x,
+                ctx.dense_kv_indptr,
+                ctx.sparse_kv_indptr,
+                ctx.dense_kv_indices,
+                o,
+                ctx.batch_size * ctx.num_kv_heads,
+                ctx.topk_val,
+                ctx.page_reserved_bos,
+                ctx.page_reserved_eos,
+                ctx.max_num_pages_per_request,
+            )
         return o


From edbf7899b34bb70724839b466c208750c2f6df94 Mon Sep 17 00:00:00 2001
From: UED
Date: Mon, 9 Mar 2026 05:30:20 +0000
Subject: [PATCH 7/7] add parameter to switch between two topk kernels (naive
 or sglang)

---
 examples/verify_algo.sh | 1 +
 setup.py                | 1 +
 2 files changed, 2 insertions(+)

diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh
index 7487708..73ac2f4 100644
--- a/examples/verify_algo.sh
+++ b/examples/verify_algo.sh
@@ -19,6 +19,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
     --topk-val 30 \
     --vortex-module-name "${algo}" \
     --model-name Qwen/Qwen3-1.7B \
+    --topk-type sglang \
     --mem 0.7 ; } \
     2>&1 | tee "${OUTFILE}"
 done
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f35ddae..99c6529 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,7 @@
         'csrc/register.cc',
         'csrc/utils_sglang.cu',
         'csrc/topk.cu',
+        'csrc/topk_sglang.cu',
     ],
     include_dirs=['csrc'],
     extra_compile_args={