
Vortex quantization and topk kernel adaption #1

Open

zxr-creator wants to merge 7 commits into Infini-AI-Lab:v1 from zxr-creator:v1
Conversation

@zxr-creator

Add INT8 Quantization Support for Vortex

This PR adds INT8 quantization support to the Vortex sparse attention framework to reduce memory usage and enable low-precision execution.

Main Changes

Implement INT8 KV-cache quantization, with layout adjustments to improve memory utilization.

Add preliminary FP8 quantization support.

Update reduce_pp_kernel parameters to support quantized data.

Adapt the Top-K kernel from SGLang for Vortex.

Add a runtime parameter to switch between two Top-K kernels (naive and sglang).

Add RTX PRO 6000 compatibility and fix several Vortex kernel issues.

Key changes:
1. Memory Pool (`vtx_graph_memory_pool.py`):
   - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations.
   - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout.
   - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids), so the Vortex sparse indexer/Top-K path is unaffected and remains mathematically identical.
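To illustrate why the parallel scale buffers are cheap, here is a rough sizing sketch (a back-of-the-envelope model, not the actual `VTXGraphCachePool` code; the function name and shapes are assumptions):

```python
def pool_bytes(num_pages, page_size, num_kv_heads, head_dim):
    """Hypothetical per-layer cache sizes: int8 K/V pages plus
    per-token float32 scales, versus a plain bf16 cache."""
    tokens = num_pages * page_size
    kv_int8 = 2 * tokens * num_kv_heads * head_dim * 1   # k + v, 1 byte/elem
    scales  = 2 * tokens * num_kv_heads * 4              # k_scale + v_scale, fp32
    kv_bf16 = 2 * tokens * num_kv_heads * head_dim * 2   # bf16 baseline, 2 bytes/elem
    return kv_int8 + scales, kv_bf16

quantized, baseline = pool_bytes(1024, 16, 8, 128)
```

With these example dimensions the int8 pool plus scales is roughly half the bf16 baseline; the scale buffers add only `4 / head_dim` relative overhead.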

2. Quantize-on-Write (`set_kv.py`):
   - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`).
   - Wired the new launcher into the cache update flow.
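The per-token absmax scheme described above can be sketched in plain Python (a reference model of the math only; the real `set_kv_buffer_int8_kernel` is a Triton kernel operating on paged buffers):

```python
def quantize_token(x):
    """Per-token absmax int8 quantization: scale = max(abs(x)) / 127.0.
    Returns the int8 codes and the float32 scale stored alongside them."""
    scale = max(abs(v) for v in x) / 127.0
    if scale == 0.0:
        return [0] * len(x), 0.0
    return [max(-127, min(127, round(v / scale))) for v in x], scale

def dequantize_token(q, scale):
    """Inverse mapping used on the read path: x_hat = q * scale."""
    return [qi * scale for qi in q]
```

Because the scale is per token, a single outlier value only degrades precision within its own token rather than across the whole page.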

3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`):
   - Bypassed FlashInfer for INT8 decoding.
   - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers.
   - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`).
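The decode-side access pattern can be mimicked in pure Python as follows (a stand-in for `paged_decode_int8`'s gather-and-dequantize step; the function name and toy layout are assumptions, and the real kernel fuses this with the attention math in SRAM):

```python
def decode_gather_dequant(k_int8, k_scale, indptr, indices, batch_id):
    """For one request, walk its sparse page list (CSR-style indptr/indices)
    and dequantize tokens inline, never materializing the full cache."""
    out = []
    for page in indices[indptr[batch_id]:indptr[batch_id + 1]]:
        for tok_q, tok_s in zip(k_int8[page], k_scale[page]):
            out.append([q * tok_s for q in tok_q])
    return out
```

Only the pages named by the sparse routing indices are ever read, so VRAM traffic scales with the selected Top-K pages, not the cache size.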

4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`):
   - Implemented an OOM-safe `bf16` fallback for prefill.
   - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer.
   - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs.
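The compaction idea behind `dequant_paged_int8_to_bf16` can be sketched like this (a simplified Python model; the helper name, return shape, and remap structure here are illustrative assumptions, not the kernel's actual interface):

```python
def dequant_accessed_pages(pages_int8, scales, accessed):
    """Dequantize only the pages touched by the current batch into a
    compact buffer, deduplicating repeats, and return a remap from
    original page id to compacted slot for the prefill planner."""
    remap, buf = {}, []
    for page in accessed:
        if page not in remap:
            remap[page] = len(buf)
            buf.append([[q * s for q in tok]
                        for tok, s in zip(pages_int8[page], scales[page])])
    return buf, remap
```

The prefill planner is then pointed at the remapped slot indices, so peak memory is bounded by the batch's working set instead of the full cache.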
