Vortex quantization and topk kernel adaptation #1
Open
zxr-creator wants to merge 7 commits into Infini-AI-Lab:v1 from
Conversation
Key changes:

1. Memory Pool (`vtx_graph_memory_pool.py`):
   - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations.
   - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout.
   - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids) to ensure the Vortex sparse indexer/TopK remains unaffected and mathematically identical.
2. Quantize-on-Write (`set_kv.py`):
   - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`).
   - Wired the new launcher into the cache update flow.
3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`):
   - Bypassed FlashInfer for INT8 decoding.
   - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers.
   - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`).
4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`):
   - Implemented an OOM-safe `bf16` fallback for prefill.
   - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer.
   - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs.
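The per-token absmax scheme described above (`scale = max(abs(x)) / 127.0`) can be sketched in plain NumPy. This is an illustrative standalone version, not the actual Triton kernels from the PR (`set_kv_buffer_int8_kernel` writes int8 pages plus float32 scales; the decode kernel applies the inverse transform inline):

```python
import numpy as np

def quantize_per_token_int8(x: np.ndarray):
    """Quantize each token (last-dim row) to int8 with its own absmax scale.

    Mirrors the quantize-on-write scheme described in the PR:
    scale = max(|x|) / 127.0, q = round(x / scale), clipped to [-127, 127].
    Returns the int8 values and the float32 per-token scales.
    """
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0.0, 1.0, scale)  # avoid div-by-zero on all-zero tokens
    q = np.clip(np.rint(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Inverse transform, analogous to the fused inline dequantization at decode time."""
    return q.astype(np.float32) * scale
```

Because values are rounded to the nearest quantization step, the per-element reconstruction error is bounded by half of that token's scale.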
Add INT8 Quantization Support for Vortex
This PR adds INT8 quantization support to the Vortex sparse attention framework to reduce memory usage and enable low-precision execution.
Main Changes
Implement INT8 quantization with adjustments to improve memory utilization.
Add preliminary FP8 quantization support.
Update reduce_pp_kernel parameters to support quantized data.
Adapt the Top-K kernel from SGLang for Vortex.
Add a runtime parameter to switch between two Top-K kernels (naive and sglang).
Add RTX PRO 6000 compatibility and fix several Vortex kernel issues.
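The runtime switch between the two Top-K kernels might look like the following minimal sketch. The parameter values `"naive"` and `"sglang"` follow the PR description, but the dispatcher name and signature are hypothetical, and the SGLang-derived Triton kernel is stubbed with the naive implementation so the example stays self-contained:

```python
from typing import Literal, Sequence

def topk_indices_naive(scores: Sequence[float], k: int) -> list:
    """Naive Top-K: sort all candidates, keep the k highest-scoring indices."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]

def select_topk(scores: Sequence[float], k: int,
                kernel: Literal["naive", "sglang"] = "naive") -> list:
    """Dispatch between Top-K implementations via a runtime parameter.

    Hypothetical dispatcher: in the PR, the "sglang" branch would invoke the
    Top-K kernel adapted from SGLang; here it falls back to the naive version
    so the sketch runs without Triton or a GPU.
    """
    if kernel == "sglang":
        # Stand-in for the SGLang-derived kernel (same output, different speed).
        return topk_indices_naive(scores, k)
    return topk_indices_naive(scores, k)
```

Both branches must return the same indices; the switch only trades off kernel launch overhead against throughput on large candidate sets.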