diff --git a/csrc/register.cc b/csrc/register.cc index fd9d4eb..532fcdf 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,6 +8,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); m.def("topk_output", &topk_output); + m.def("topk_output_sglang", &topk_output_sglang); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 92499ed..b81168b 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -85,6 +85,18 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_seq_lengths +); void sglang_plan_decode_fa3( const at::Tensor& cached_seq_lens, diff --git a/csrc/topk.cu b/csrc/topk.cu index 62d747e..3aa49b9 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -200,4 +200,4 @@ const int64_t max_num_pages TORCH_CHECK(false); } -} +} \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e290a81..1187aca 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -54,13 +54,15 @@ def verify_algos( vortex_module_name: str = "gqa_block_sparse_attention", model_name: str = "Qwen/Qwen3-1.7B", sparse_attention: bool = True, -mem: float = 0.8 +mem: float = 0.8, +kv_cache_dtype: str = "auto", +topk_type: str = "naive", ): - llm = sgl.Engine(model_path=model_name, + llm = sgl.Engine(model_path=model_name, disable_cuda_graph=False, page_size=page_size, - vortex_topk_val=topk_val, + vortex_topk_val=topk_val, 
disable_overlap_schedule=True, attention_backend="flashinfer", enable_vortex_sparsity=sparse_attention, @@ -69,10 +71,12 @@ def verify_algos( vortex_layers_skip=list(range(1)), vortex_module_name=vortex_module_name, vortex_max_seq_lens=12288, - mem_fraction_static=mem + mem_fraction_static=mem, + kv_cache_dtype=kv_cache_dtype, + vortex_topk_type=topk_type, ) - with open("examples/amc23.jsonl", "r", encoding="utf-8") as f: + with open("amc23.jsonl", "r", encoding="utf-8") as f: requests = [json.loads(line) for line in f] requests = requests * trials @@ -110,6 +114,14 @@ def verify_algos( "num_tokens": item["meta_info"]["completion_tokens"] } ) + # --- Per-question debug output --- + # print(f"[Q{len(results):03d}] score={float(result):.1f} " + # f"tokens={item['meta_info']['completion_tokens']} " + # f"latency={item['meta_info']['e2e_latency']:.2f}s " + # f"gold={golds[0]}") + # print(f" question: {data['question'][:120]}...") + # print(f" prediction: {predictions[:200]}...") + # print() total_accuracy = 0.0 @@ -203,6 +215,22 @@ def parse_args(): default=0.8, help="memory fraction in sglang", ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + default="auto", + choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], + help='KV cache dtype (default: "auto").', + ) + + parser.add_argument( + "--topk-type", + type=str, + default="naive", + choices=["naive", "sglang"], + help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', + ) return parser.parse_args() if __name__ == "__main__": @@ -215,7 +243,9 @@ def parse_args(): vortex_module_name=args.vortex_module_name, model_name=args.model_name, sparse_attention=not(args.full_attention), - mem=args.mem + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, ) print(summary) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 17c2a5e..73ac2f4 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,17 +1,25 @@ 
#!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( - + "block_sparse_attention" ) -for algo in "${sparse_algos[@]}"; do - echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" - python examples/verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --mem 0.7 -done +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh new file mode 100755 index 0000000..c0b8814 --- /dev/null +++ b/examples/verify_algo_fp8.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +# export CUDA_VISIBLE_DEVICES=0 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype fp8_e4m3 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh new file mode 100644 index 0000000..bf24c2d --- /dev/null +++ b/examples/verify_algo_int8.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +# export CUDA_VISIBLE_DEVICES=0 + 
+sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/setup.py b/setup.py index e272326..99c6529 100644 --- a/setup.py +++ b/setup.py @@ -16,15 +16,19 @@ sources=[ 'csrc/register.cc', 'csrc/utils_sglang.cu', - 'csrc/topk.cu' + 'csrc/topk.cu', + 'csrc/topk_sglang.cu', ], include_dirs=['csrc'], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': [ '-O3', + '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', - '-gencode=arch=compute_90,code=sm_90' + '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' ], }, ), diff --git a/third_party/sglang b/third_party/sglang index e383c0f..7105719 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit e383c0fdd551f74f24d247e8a7cc8013861949ad +Subproject commit 7105719f0a2ac464ee7ffdc0a899fa6a656656a2 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index eddfa46..8c4d0e0 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,11 +29,14 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ 
"set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", + "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index ae2dd5c..0e7171c 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -10,17 +10,23 @@ class Context(ContextBase): """ __slots__ = ContextBase.__slots__ + ( - + #page infomation "max_new_tokens_per_batch", "page_size", "total_num_pages", - + #model infomation "head_dim", "head_num", - + # auxilary memory in graph "_aux_total_bytes", - - "_aux_total_flops" + + "_aux_total_flops", + + # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), + # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + "quant_type", + "kv_scale", + "kv_scale_ptr", ) @@ -36,7 +42,13 @@ def __init__(self) -> None: elif name == "_aux_total_flops": object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": - object.__setattr__(self, name, Mode.profile) + object.__setattr__(self, name, Mode.profile) + elif name == "quant_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) + elif name == "kv_scale": + object.__setattr__(self, name, 1.0) # identity scale for bf16 + elif name == "kv_scale_ptr": + object.__setattr__(self, name, None) # per-token scale tensor (int8 only) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 3c4edf2..eb94795 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,8 +345,11 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr) + quant_type = 
getattr(ctx, 'quant_type', 0) + scale = getattr(ctx, 'kv_scale', 1.0) + kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr) return output diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 6bf6dfc..009e728 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,4 +1,13 @@ -from .set_kv import set_kv_buffer_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher +from .paged_decode_int8 import paged_decode_int8 +from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace -__all__ = ["set_kv_buffer_launcher"] +__all__ = [ + "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", + "paged_decode_int8", + "dequant_paged_int8_to_bf16", + "dequant_paged_int8_to_bf16_inplace", +] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py new file mode 100644 index 0000000..4f33cd4 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -0,0 +1,363 @@ +""" +Custom Triton paged decode attention kernel for int8 KV cache. + +Loads int8 K/V pages with per-token float32 scales, dequantizes inline in SRAM, +and computes standard multi-head attention with online softmax. + +Adapted from SGLang's decode_attention.py for use with Vortex's paged layout +where each KV head is treated as a separate "batch" entry. 
+""" + +import torch +import triton +import triton.language as tl + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_int8_stage1( + Q, # [batch, num_qo_heads, head_dim] bf16 + K_Buffer, # int8 paged: flat + V_Buffer, # int8 paged: flat + K_Scale_Buffer, # fp16: flat (one scale per token slot) + V_Scale_Buffer, # fp16: flat + sm_scale, + kv_indptr, # [batch + 1] int32, page-level + kv_indices, # page indices + last_page_len, # [batch] int32, tokens valid in last page + Att_Out, # [batch, num_qo_heads, max_kv_splits, head_dim] + Att_Lse, # [batch, num_qo_heads, max_kv_splits] + num_kv_splits, # [batch] int32 + stride_qbs, + stride_qh, + stride_buf_kbs, # stride per token in K_Buffer (= head_dim) + stride_buf_vbs, # stride per token in V_Buffer (= head_dim) + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """ + Stage 1: For each (batch, head, kv_split), compute partial attention output and LSE. + + kv_indptr is page-level. 
Total tokens for batch i: + (num_pages - 1) * PAGE_SIZE + last_page_len[i] + """ + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + # Correct token count accounting for partial last page + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + # Convert token offsets to page_id + in-page offset + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + + # Load page indices from kv_indices (physical page IDs) + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, + other=0, + ) + + # Flat token location: physical_page * PAGE_SIZE + in_page_offset + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load int8 K and dequantize + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], + other=0, + ).to(tl.float32) + + k_scale = tl.load( + 
K_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ).to(tl.float32) + k = k_int8 * k_scale[:, None] + + # Compute QK + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load int8 V and dequantize + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], + other=0, + ).to(tl.float32) + + v_scale = tl.load( + V_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ).to(tl.float32) + v = v_int8 * v_scale[:, None] + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=mask_dv, + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +@triton.jit +def _fwd_kernel_int8_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + 
cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode_int8( + q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 + k_buffer: torch.Tensor, # int8 paged K cache + v_buffer: torch.Tensor, # int8 paged V cache + k_scale_buffer: torch.Tensor, # fp16 scale for K + v_scale_buffer: torch.Tensor, # fp16 scale for V + o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output + kv_indptr: torch.Tensor, # [batch + 1] int32, page-level + kv_indices: torch.Tensor, # page indices + last_page_len: torch.Tensor, # [batch] int32 + num_kv_splits: torch.Tensor, # [batch] int32 + max_kv_splits: int, + sm_scale: float, + page_size: int, + logit_cap: float = 0.0, + att_out: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits, Lv] + att_lse: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits] +): + """ + Paged decode attention with int8 KV cache and 
inline dequantization. + + kv_indptr is page-level. last_page_len specifies valid tokens in the last page + for each batch entry. Total tokens = (num_pages - 1) * page_size + last_page_len. + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 64 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + + num_warps = 4 if kv_group_num == 1 else 2 + + # Use pre-allocated buffers if provided, otherwise allocate + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + else: + att_lse = att_lse[:batch] + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_int8_stage1[grid_stage1]( + q, + k_buffer, + v_buffer, + k_scale_buffer, + v_scale_buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + att_out, + att_lse, + num_kv_splits, + q.stride(0), + q.stride(1), + stride_buf_kbs, + stride_buf_vbs, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + PAGE_SIZE=page_size, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_int8_stage2[grid_stage2]( + att_out, + att_lse, + o, + kv_indptr, + last_page_len, + num_kv_splits, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + o.stride(0), + o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) diff --git 
a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py new file mode 100644 index 0000000..8927983 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -0,0 +1,168 @@ +""" +OOM-safe bf16 fallback for int8 KV-cache prefill. + +Instead of implementing full 2D-tiled Triton prefill with int8 dequantization, +this module dequantizes only the accessed KV pages into a compact temporary +bf16 buffer and remaps indices so FlashInfer can operate on the compact buffer. + +This avoids dequantizing the entire global cache buffer. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _dequant_pages_kernel( + src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat + src_scale, # fp16 scale buffer [num_pages, page_size, 1] flat + dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat + page_indices, # int32 [num_accessed_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16 compact buffer.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source: global_page_id * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + src_offset, mask=mask_dim, other=0).to(tl.float32) + + # Scale: global_page_id * PAGE_SIZE + token_idx + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset).to(tl.float32) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Destination: page_idx * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + 
dims + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + tl.store(dst_bf16 + dst_offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16( + src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] + src_scale: torch.Tensor, # fp16 [num_pages, page_size, 1] + page_indices: torch.Tensor, # int32 [num_accessed_pages] + page_size: int, + head_dim: int, + out: torch.Tensor = None, # optional pre-allocated bf16 [>=num_accessed_pages, page_size, head_dim] +) -> torch.Tensor: + """ + Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + + If `out` is provided, writes into it (must have room for num_accessed_pages). + Otherwise allocates a new buffer. + + Returns: + bf16 tensor of shape [num_accessed_pages, page_size, head_dim] + """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + if out is not None: + return out[:0] + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) + + if out is not None: + dst_bf16 = out[:num_accessed_pages] + else: + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) + + return dst_bf16 + + +@triton.jit +def _dequant_pages_inplace_kernel( + src_int8, # int8 paged buffer flat + src_scale, # scale buffer flat (one scale per token slot) + dst_bf16, # bf16 destination buffer (same page layout as src) + page_indices, # int32 [num_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16, writing to the SAME page positions in dst.""" + page_idx = 
tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source and destination use the SAME offset (in-place layout) + offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + offset, mask=mask_dim, other=0).to(tl.float32) + + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset).to(tl.float32) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Write to the SAME page position in dst (not compacted) + tl.store(dst_bf16 + offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16_inplace( + src_int8: torch.Tensor, # int8 paged cache (flat) + src_scale: torch.Tensor, # fp16 scale buffer (flat) + dst_bf16: torch.Tensor, # bf16 destination (same shape as src_int8) + page_indices: torch.Tensor, # int32 [num_pages] — which pages to dequant + page_size: int, + head_dim: int, +) -> None: + """ + Dequantize selected pages from int8 cache to bf16 IN-PLACE. + + Unlike dequant_paged_int8_to_bf16 (which compacts into a dense buffer), + this writes to the SAME page positions in dst_bf16, preserving the paged layout. + Used to populate the bf16 working buffer for forward_cache (centroid computation). 
+ """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_inplace_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9921e08..0146af7 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -4,6 +4,17 @@ from ..context import Context from ...utils import ReduceType + +# --------------------------------------------------------------------------- +# Helper: Load a page block from src_ptr, handling bf16 / int8 / fp8-stored-as-uint8. +# QUANT_TYPE == 0 -> bf16 pointer, load normally +# QUANT_TYPE == 1 -> int8 pointer, dequant with per-row scale from kv_scale_ptr +# QUANT_TYPE == 2 -> uint8 pointer, bitcast to float8e4nv, dequant with per-tensor scale +# QUANT_TYPE == 3 -> uint8 pointer, bitcast to float8e5, dequant with per-tensor scale +# All quantised paths return a float32 tensor ready for reduction. 
+# --------------------------------------------------------------------------- + + @triton.jit def reduce_pp_kernel( x, output, loc, @@ -12,9 +23,12 @@ def reduce_pp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm -DIM: tl.constexpr # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -): - +DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) +): + token_id = tl.program_id(0) head_id = tl.program_id(1) @@ -29,7 +43,22 @@ def reduce_pp_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - page_block = tl.load(src_ptr) + + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + # Per-row scales stored at kv_scale_ptr[page_id * x_D0 + row] + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) # [x_D0] + page_block = raw * row_scales[:, None] # broadcast [x_D0, 1] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -71,11 +100,14 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -85,7 +117,10 @@ def reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, 
PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -97,11 +132,14 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -111,7 +149,10 @@ def _reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -124,9 +165,12 @@ def reduce_rp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) - DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + scale, # float: 1.0 for bf16, kv_scale for fp8 + kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): - + # Program IDs: # pid0 = token index (0 .. num_tokens-1) # pid1 = head index (0 .. NUM_KV_HEAD-1) @@ -156,7 +200,20 @@ def reduce_rp_kernel( # Load the full page block for this (token_id, head_id). # Assumes the page is full; add masks here if you have partial tiles. 
- page_block = tl.load(src_ptr) + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) # Reduction: if DIM == 1: @@ -196,7 +253,7 @@ def reduce_rp_kernel( # Write to output: layout [num_pages, x_D0] for DIM==2. dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) - + def reduce_rp( @@ -206,11 +263,14 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -220,7 +280,10 @@ def reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -232,11 +295,14 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -246,7 +312,10 @@ def _reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -258,7 +327,10 @@ def reduce_pr_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm 
(not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): """ Layouts: @@ -297,7 +369,20 @@ def reduce_pr_kernel( src_ptr = x + x_offset + rows * x_D1 + cols # Load the full page block. Assumes full tiles; add masks if needed. - page_block = tl.load(src_ptr) + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) # --- Reduction & write-out --- if DIM == 1: @@ -344,11 +429,14 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -358,9 +446,12 @@ def reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) - + def _reduce_pr( x: torch.Tensor, output: torch.Tensor, @@ -369,11 +460,14 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( 
x=x, output=output, @@ -383,7 +477,10 @@ def _reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -395,7 +492,10 @@ def reduce_rr_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): """ Layouts: @@ -420,7 +520,22 @@ def reduce_rr_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_base + rows * x_D1 + cols - page_blk = tl.load(src_ptr) # assumes full page; add masks if needed + + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_blk = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_blk = tl.load(src_ptr) # assumes full page; add masks if needed # ---- reduce ---- if DIM == 1: @@ -464,11 +579,14 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( 
x=x, output=output, @@ -478,9 +596,12 @@ def reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) - + def _reduce_rr( x: torch.Tensor, @@ -490,11 +611,14 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -504,5 +628,8 @@ def _reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim - ) \ No newline at end of file + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index cfa3cab..6b289df 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -36,6 +36,93 @@ def set_kv_buffer_kernel( tl.store(dst_v_ptr, src_v) +@triton.jit +def set_kv_buffer_int8_kernel( + k_cache, # int8 paged K cache + v_cache, # int8 paged V cache + k_scale_cache, # fp16 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # fp16 per-token V scale [num_pages, page_size, 1] + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr +): + """Quantize bf16 K/V to int8 with per-token absmax scaling and write to paged buffers.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + 
src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Compute per-token absmax scale: scale = absmax / 127 + absmax_k = tl.max(tl.abs(src_k), axis=0) + absmax_v = tl.max(tl.abs(src_v), axis=0) + # Avoid division by zero + scale_k = absmax_k / 127.0 + 1e-10 + scale_v = absmax_v / 127.0 + 1e-10 + + # Quantize to int8: round(x / scale), clamp to [-128, 127] + q_k = tl.extra.cuda.libdevice.rint(src_k / scale_k) + q_k = tl.minimum(tl.maximum(q_k, -128.0), 127.0).to(tl.int8) + q_v = tl.extra.cuda.libdevice.rint(src_v / scale_v) + q_v = tl.minimum(tl.maximum(q_v, -128.0), 127.0).to(tl.int8) + + # Compute paged destination offset (same layout as bf16 kernel) + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write per-(token, head) scales (fp16): shape [num_pages, num_kv_heads, page_size] + # Layout: (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset (one scale per token per head) + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + # Write int8 values + scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset + tl.store(k_scale_cache + scale_offset, scale_k.to(tl.float16)) + tl.store(v_scale_cache + scale_offset, scale_v.to(tl.float16)) + + +def set_kv_buffer_int8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int +): + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_int8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + new_k, + new_v, + loc, + NUM_KV_HEAD, + NNZ, + HEAD_DIM, + page_size + ) + + def 
set_kv_buffer_launcher( k_cache: torch.Tensor, v_cache: torch.Tensor, @@ -44,11 +131,11 @@ def set_kv_buffer_launcher( loc: torch.LongTensor, page_size: int ): - + NNZ = loc.shape[0] NUM_KV_HEAD = new_k.shape[1] HEAD_DIM = new_k.shape[2] - + set_kv_buffer_kernel[(NNZ, NUM_KV_HEAD)]( k_cache, v_cache, @@ -61,3 +148,96 @@ def set_kv_buffer_launcher( page_size ) + +@triton.jit +def set_kv_buffer_fp8_kernel( + k_cache, # uint8 paged K cache + v_cache, # uint8 paged V cache + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + FP8_TYPE: tl.constexpr, # 1: e4m3 (max=448), 2: e5m2 (max=57344) + k_scale, # float: per-tensor scale for K quantization + v_scale, # float: per-tensor scale for V quantization +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Scale down: quantized = real_value / scale + inv_k_scale = 1.0 / k_scale + inv_v_scale = 1.0 / v_scale + scaled_k = src_k * inv_k_scale + scaled_v = src_v * inv_v_scale + + # Clamp and cast to fp8, then bitcast to uint8 for storage + if FP8_TYPE == 1: + # e4m3: max = 448.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -448.0), 448.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -448.0), 448.0) + q_k = clamped_k.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + else: + # e5m2: max = 57344.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -57344.0), 57344.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -57344.0), 57344.0) + 
q_k = clamped_k.to(tl.float8e5).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e5).to(tl.uint8, bitcast=True) + + # Compute paged destination offset + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write uint8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + +def set_kv_buffer_fp8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int, + k_scale: float, + v_scale: float, + fp8_type: int = 1, +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache. + + Args: + fp8_type: 1 for e4m3 (default), 2 for e5m2. + k_scale: per-tensor scale used for K quantization. + v_scale: per-tensor scale used for V quantization. + """ + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_fp8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, v_cache, + new_k, new_v, + loc, + NUM_KV_HEAD, NNZ, HEAD_DIM, page_size, + FP8_TYPE=fp8_type, + k_scale=k_scale, + v_scale=v_scale, + ) + diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 6d3c586..d6da9c1 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -22,7 +22,7 @@ class Context(ContextBase): # hardware / paging "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc - "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", + "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", # auxilary memory in graph "_aux_total_bytes", @@ -68,6 +68,7 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. 
page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). + topk_type: str #: TopK kernel type: "naive" or "sglang". # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -144,6 +145,7 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos + self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 5df795b..f7d0d9c 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Callable, Optional from ..abs import vOp -from vortex_torch_C import topk_output +from vortex_torch_C import topk_output, topk_output_sglang from .context import Context from ..abs import vTensor, FORMAT @@ -75,13 +75,17 @@ class topK(vOp): """ # Dispatch by input format; only RAGGED is supported for now. - _impl_map: Dict[FORMAT, Callable] = { - FORMAT.RAGGED: topk_output, + _impl_map: Dict[FORMAT, Dict[str, Callable]] = { + FORMAT.RAGGED: { + "naive": topk_output, + "sglang": topk_output_sglang, + }, } def __init__(self): super().__init__() self.impl: Optional[Callable] = None + self.topk_type: str = "naive" # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -152,7 +156,13 @@ def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: f"{prefix}no implementation for x._format={x_fmt}. 
" f"Available: {list(self._impl_map.keys())}" ) - self.impl = self._impl_map[x_fmt] + self.topk_type = getattr(ctx, "topk_type", "naive") + impl_variants = self._impl_map[x_fmt] + assert self.topk_type in impl_variants, ( + f"{prefix}no topk implementation for topk_type='{self.topk_type}'. " + f"Available: {list(impl_variants.keys())}" + ) + self.impl = impl_variants[self.topk_type] # ---- optional sanity checks on `o` ---- # We only assert device consistency and leave exact (S_pack, D0, D1) @@ -220,16 +230,32 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso prefix = self._prefix() assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" - self.impl( - x, - ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, - ctx.dense_kv_indices, - o, - ctx.batch_size * ctx.num_kv_heads, - ctx.topk_val, - ctx.page_reserved_bos, - ctx.page_reserved_eos, - ctx.max_num_pages_per_request, - ) + if self.topk_type == "sglang": + # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + else: + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.dense_kv_indices, + ctx.sparse_kv_indptr, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) return o