Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/register.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ PYBIND11_MODULE(vortex_torch_C, m){
m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose);
m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose);
m.def("topk_output", &topk_output);
m.def("topk_output_sglang", &topk_output_sglang);
m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3);
m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3);
m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3);
Expand Down
12 changes: 12 additions & 0 deletions csrc/register.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ const int64_t reserved_eos,
const int64_t max_seq_lengths
);

// Top-k page/index selection variant wired for the SGLang backend
// (bound in register.cc as "topk_output_sglang"; implementation lives in
// csrc/topk_sglang.cu — not visible here).
//
// Parameters (semantics inferred from names — TODO confirm against the kernel):
//   x                 - input scores tensor the top-k is taken over (read-only)
//   dense_kv_indptr   - CSR-style indptr into the dense KV index list (read-only)
//   sparse_kv_indptr  - CSR-style indptr for the sparse output layout (read-only)
//   dense_kv_indices  - dense KV indices to select from (read-only)
//   sparse_kv_indices - OUTPUT: selected sparse KV indices (mutated in place;
//                       note this is the only non-const tensor argument)
//   eff_batch_size    - number of batch rows actually processed
//   topk_val          - k, the number of entries kept per row
//   reserved_bos      - presumably indices reserved at the sequence start — verify
//   reserved_eos      - presumably indices reserved at the sequence end — verify
//   max_seq_lengths   - upper bound on sequence length used for sizing/clamping
void topk_output_sglang(
const at::Tensor& x,
const at::Tensor& dense_kv_indptr,
const at::Tensor& sparse_kv_indptr,
const at::Tensor& dense_kv_indices,
at::Tensor& sparse_kv_indices,
const int64_t eff_batch_size,
const int64_t topk_val,
const int64_t reserved_bos,
const int64_t reserved_eos,
const int64_t max_seq_lengths
);

void sglang_plan_decode_fa3(
const at::Tensor& cached_seq_lens,
Expand Down
2 changes: 1 addition & 1 deletion csrc/topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,4 @@ const int64_t max_num_pages
TORCH_CHECK(false);
}

}
}
42 changes: 36 additions & 6 deletions examples/verify_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def verify_algos(
vortex_module_name: str = "gqa_block_sparse_attention",
model_name: str = "Qwen/Qwen3-1.7B",
sparse_attention: bool = True,
mem: float = 0.8
mem: float = 0.8,
kv_cache_dtype: str = "auto",
topk_type: str = "naive",
):

llm = sgl.Engine(model_path=model_name,
llm = sgl.Engine(model_path=model_name,
disable_cuda_graph=False,
page_size=page_size,
vortex_topk_val=topk_val,
vortex_topk_val=topk_val,
disable_overlap_schedule=True,
attention_backend="flashinfer",
enable_vortex_sparsity=sparse_attention,
Expand All @@ -69,10 +71,12 @@ def verify_algos(
vortex_layers_skip=list(range(1)),
vortex_module_name=vortex_module_name,
vortex_max_seq_lens=12288,
mem_fraction_static=mem
mem_fraction_static=mem,
kv_cache_dtype=kv_cache_dtype,
vortex_topk_type=topk_type,
)

with open("examples/amc23.jsonl", "r", encoding="utf-8") as f:
with open("amc23.jsonl", "r", encoding="utf-8") as f:
requests = [json.loads(line) for line in f]

requests = requests * trials
Expand Down Expand Up @@ -110,6 +114,14 @@ def verify_algos(
"num_tokens": item["meta_info"]["completion_tokens"]
}
)
# --- Per-question debug output ---
# print(f"[Q{len(results):03d}] score={float(result):.1f} "
# f"tokens={item['meta_info']['completion_tokens']} "
# f"latency={item['meta_info']['e2e_latency']:.2f}s "
# f"gold={golds[0]}")
# print(f" question: {data['question'][:120]}...")
# print(f" prediction: {predictions[:200]}...")
# print()


total_accuracy = 0.0
Expand Down Expand Up @@ -203,6 +215,22 @@ def parse_args():
default=0.8,
help="memory fraction in sglang",
)

parser.add_argument(
"--kv-cache-dtype",
type=str,
default="auto",
choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"],
help='KV cache dtype (default: "auto").',
)

parser.add_argument(
"--topk-type",
type=str,
default="naive",
choices=["naive", "sglang"],
help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").',
)
return parser.parse_args()

if __name__ == "__main__":
Expand All @@ -215,7 +243,9 @@ def parse_args():
vortex_module_name=args.vortex_module_name,
model_name=args.model_name,
sparse_attention=not(args.full_attention),
mem=args.mem
mem=args.mem,
kv_cache_dtype=args.kv_cache_dtype,
topk_type=args.topk_type,
)
print(summary)

Expand Down
28 changes: 18 additions & 10 deletions examples/verify_algo.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
#!/usr/bin/env bash
set -e
# export CUDA_VISIBLE_DEVICES=0

sparse_algos=(

"block_sparse_attention"
)

for algo in "${sparse_algos[@]}"; do
echo ">>> Running verify_algo.py with --vortex-module-name ${algo}"
python examples/verify_algo.py \
--trials 8 \
--topk-val 30 \
--vortex-module-name "${algo}" \
--model-name Qwen/Qwen3-1.7B \
--mem 0.7
done
RESULTS_DIR="results"
mkdir -p "${RESULTS_DIR}"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

for algo in "${sparse_algos[@]}"; do
OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log"
echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16"
echo ">>> Saving results to ${OUTFILE}"
{ time python verify_algo.py \
--trials 8 \
--topk-val 30 \
--vortex-module-name "${algo}" \
--model-name Qwen/Qwen3-1.7B \
--topk-type sglang \
--mem 0.7 ; } \
2>&1 | tee "${OUTFILE}"
done
25 changes: 25 additions & 0 deletions examples/verify_algo_fp8.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Run verify_algo.py with an fp8_e4m3 KV cache for each sparse-attention
# module, timing each run and teeing output to a timestamped log in results/.
#
# -e: abort on any command failure; -u: treat unset variables as errors;
# -o pipefail: without it, the `python ... | tee` pipeline's exit status is
# tee's (almost always 0), so a failing python run would be silently ignored
# under plain `set -e`.
set -euo pipefail
# export CUDA_VISIBLE_DEVICES=0

sparse_algos=(
"block_sparse_attention"
)

RESULTS_DIR="results"
mkdir -p "${RESULTS_DIR}"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

for algo in "${sparse_algos[@]}"; do
OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log"
echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3"
echo ">>> Saving results to ${OUTFILE}"
# `time` wraps the whole python invocation; stdout+stderr both go to the log.
{ time python verify_algo.py \
--trials 8 \
--topk-val 30 \
--vortex-module-name "${algo}" \
--model-name Qwen/Qwen3-1.7B \
--kv-cache-dtype fp8_e4m3 \
--mem 0.7 ; } \
2>&1 | tee "${OUTFILE}"
done
25 changes: 25 additions & 0 deletions examples/verify_algo_int8.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Run verify_algo.py with an int8 KV cache for each sparse-attention module,
# timing each run and teeing output to a timestamped log in results/.
#
# -e: abort on any command failure; -u: treat unset variables as errors;
# -o pipefail: without it, the `python ... | tee` pipeline's exit status is
# tee's (almost always 0), so a failing python run would be silently ignored
# under plain `set -e`.
set -euo pipefail
# export CUDA_VISIBLE_DEVICES=0

sparse_algos=(
"block_sparse_attention"
)

RESULTS_DIR="results"
mkdir -p "${RESULTS_DIR}"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

for algo in "${sparse_algos[@]}"; do
OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log"
echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8"
echo ">>> Saving results to ${OUTFILE}"
# `time` wraps the whole python invocation; stdout+stderr both go to the log.
{ time python verify_algo.py \
--trials 8 \
--topk-val 30 \
--vortex-module-name "${algo}" \
--model-name Qwen/Qwen3-1.7B \
--kv-cache-dtype int8 \
--mem 0.7 ; } \
2>&1 | tee "${OUTFILE}"
done
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
sources=[
'csrc/register.cc',
'csrc/utils_sglang.cu',
'csrc/topk.cu'
'csrc/topk.cu',
'csrc/topk_sglang.cu',
],
include_dirs=['csrc'],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'-gencode=arch=compute_86,code=sm_86',
'-gencode=arch=compute_89,code=sm_89',
'-gencode=arch=compute_90,code=sm_90'
'-gencode=arch=compute_90,code=sm_90',
'-gencode=arch=compute_100a,code=sm_100a',
'-gencode=arch=compute_120,code=sm_120'
],
},
),
Expand Down
5 changes: 4 additions & 1 deletion vortex_torch/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
from .matmul import GeMM
from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul
from .elementwise_binary import Maximum, Minimum, Multiply, Add
from .triton_kernels import set_kv_buffer_launcher
from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace


__all__ = [
"set_kv_buffer_launcher",
"set_kv_buffer_int8_launcher",
"set_kv_buffer_fp8_launcher",
"dequant_paged_int8_to_bf16_inplace",
"Mean", "Max", "Min", "L2Norm",
"GeMM",
"Relu", "Silu", "Sigmoid", "Abs", "Add_Mul",
Expand Down
24 changes: 18 additions & 6 deletions vortex_torch/cache/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@ class Context(ContextBase):
"""

__slots__ = ContextBase.__slots__ + (

# page information
"max_new_tokens_per_batch", "page_size", "total_num_pages",

# model information
"head_dim", "head_num",

# auxiliary memory in graph
"_aux_total_bytes",

"_aux_total_flops"

"_aux_total_flops",

# Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2),
# kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor)
"quant_type",
"kv_scale",
"kv_scale_ptr",
)


Expand All @@ -36,7 +42,13 @@ def __init__(self) -> None:
elif name == "_aux_total_flops":
object.__setattr__(self, name, 0) # start from 0 flops
elif name == "mode":
object.__setattr__(self, name, Mode.profile)
object.__setattr__(self, name, Mode.profile)
elif name == "quant_type":
object.__setattr__(self, name, 0) # 0 = none (bf16 default)
elif name == "kv_scale":
object.__setattr__(self, name, 1.0) # identity scale for bf16
elif name == "kv_scale_ptr":
object.__setattr__(self, name, None) # per-token scale tensor (int8 only)
else:
object.__setattr__(self, name, UNSET)

Expand Down
7 changes: 5 additions & 2 deletions vortex_torch/cache/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,11 @@ def execute(
)
output = self.output_buffer

# Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type)
self.impl(x, output, loc, ctx, self.dim, self.reduce_type)
# Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr)
quant_type = getattr(ctx, 'quant_type', 0)
scale = getattr(ctx, 'kv_scale', 1.0)
kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None)
self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr)
return output


Expand Down
13 changes: 11 additions & 2 deletions vortex_torch/cache/triton_kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from .set_kv import set_kv_buffer_launcher
from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher
from .paged_decode_int8 import paged_decode_int8
from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace

__all__ = ["set_kv_buffer_launcher"]
__all__ = [
"set_kv_buffer_launcher",
"set_kv_buffer_int8_launcher",
"set_kv_buffer_fp8_launcher",
"paged_decode_int8",
"dequant_paged_int8_to_bf16",
"dequant_paged_int8_to_bf16_inplace",
]

Loading