feat(quant): support tensorwise fp8 w8a8 (--quant_type fp8w8a8-pt) #1366
sufubao wants to merge 1 commit into
Code Review
This pull request introduces a new quantization method, FP8w8a8PerTensorQuantizationMethod (vllm-fp8w8a8-pt), which implements tensorwise (per-tensor) dynamic FP8 W8A8 quantization. It also updates the command-line interface help text to document this new option. The review feedback suggests replacing .view(-1) with .flatten() when extracting the scalar from weight_scale to safely handle potential 0-dimensional tensors and avoid compatibility errors.
        weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
    )
    output.weight.copy_(qweight)
    output.weight_scale.fill_(weight_scale.view(-1)[0])
Using .view(-1) can raise a RuntimeError (e.g., view size is not compatible with input tensor's size and stride) when the tensor's layout is not view-compatible, for example when it is non-contiguous. A contiguous 0-dimensional (scalar) tensor is accepted by .view(-1), but weight_scale is a per-tensor scale whose shape and layout depend on what scaled_fp8_quant returns, so it is safer to use .flatten()[0] or .reshape(-1)[0]. Both handle any layout, including 0-dimensional tensors, and neither triggers host-device synchronization.
Suggested change:
-    output.weight_scale.fill_(weight_scale.view(-1)[0])
+    output.weight_scale.fill_(weight_scale.flatten()[0])
Add TritonFP8w8a8PerTensorQuantizationMethod (triton-fp8w8a8-pertensor/-pt) with deferred staging for fused multi-split and per-expert weights. Support per-tensor weight scale in the grouped MoE matmul (WEIGHT_SCALE_PER_TENSOR) and the per-token scaled_mm kernel (B_SCALE_IS_TENSOR).
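To make the scalar-versus-vector weight scale distinction concrete, here is a minimal pure-Python reference of a scaled matmul. It is only a sketch of the arithmetic, not the Triton kernel from this commit: the function name scaled_mm_reference and the list-based layout are hypothetical, and the b_scale_is_tensor variable merely mirrors the role that the commit message attributes to B_SCALE_IS_TENSOR.

```python
from typing import List, Sequence, Union

Matrix = Sequence[Sequence[float]]


def scaled_mm_reference(
    a_q: Matrix,
    b_q: Matrix,
    a_scale: Sequence[float],
    b_scale: Union[float, Sequence[float]],
) -> List[List[float]]:
    """Reference for out[m][n] = (sum_k a_q[m][k] * b_q[k][n]) * a_scale[m] * w_scale.

    a_scale holds one activation scale per row (per token).
    b_scale is either a single float (per-tensor weight scale) or one value
    per output column (per-channel weight scale).
    """
    # Hypothetical stand-in for a compile-time flag such as B_SCALE_IS_TENSOR.
    b_scale_is_tensor = isinstance(b_scale, (int, float))
    n_cols = len(b_q[0])
    out = []
    for m, row in enumerate(a_q):
        out_row = []
        for n in range(n_cols):
            acc = sum(row[k] * b_q[k][n] for k in range(len(row)))
            w_scale = b_scale if b_scale_is_tensor else b_scale[n]
            out_row.append(acc * a_scale[m] * w_scale)
        out.append(out_row)
    return out


a_q = [[1.0, 2.0], [3.0, 4.0]]
b_q = [[5.0, 6.0], [7.0, 8.0]]
a_scale = [0.5, 0.25]

# A scalar weight scale and the same value repeated per column give the same result.
per_tensor = scaled_mm_reference(a_q, b_q, a_scale, 0.1)
per_channel = scaled_mm_reference(a_q, b_q, a_scale, [0.1, 0.1])
```

Passing the scalar directly avoids materializing and loading a per-column vector whose entries are all identical, which is presumably the motivation for a dedicated per-tensor code path.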
What
Adds a tensorwise (per-tensor) FP8 W8A8 quantization method, selectable via --quant_type fp8w8a8-pt (alias vllm-fp8w8a8-pt).
How
New FP8w8a8PerTensorQuantizationMethod in lightllm/common/quantization/w8a8.py, subclassing the existing FP8w8a8QuantizationMethod and overriding only what differs for per-tensor:
- quantize: scaled_fp8_quant(..., use_per_token_if_dynamic=False) → a single scalar weight scale, broadcast across the sub-weight's output channels.
- apply: scaled_fp8_quant(input, ..., use_per_token_if_dynamic=False) → a single per-tensor activation scale, then cutlass_scaled_mm.
The weight scale is kept in the inherited per-channel storage layout (scalar broadcast across channels). This lets fused weights (qkv / gate_up) each get their own correct per-tensor scale while the cutlass per-channel weight path consumes them unchanged; per-tensor activation combines fine with per-channel weight scales in cutlass.
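The broadcast of one scalar scale per sub-weight into a per-channel layout can be sketched in plain Python. This is an illustration under stated assumptions, not the code in w8a8.py: the helper names quantize_per_tensor and quantize_fused are hypothetical, weights are assumed to be stored as [out_channels][in_features], the FP8 format is assumed to be e4m3 (largest finite value 448), and actual FP8 rounding is omitted so that values are only scaled and clamped.

```python
from typing import List, Sequence, Tuple

FP8_E4M3_MAX = 448.0  # largest finite value of the e4m3 FP8 format (assumed format)

Matrix = Sequence[Sequence[float]]


def quantize_per_tensor(weight: Matrix) -> Tuple[List[List[float]], float]:
    """Scale a whole matrix with a single scale derived from its absolute maximum."""
    amax = max(abs(v) for row in weight for v in row)
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    # Real FP8 rounding is omitted; values are only scaled and clamped.
    q = [[max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in row] for row in weight]
    return q, scale


def quantize_fused(sub_weights: Sequence[Matrix]) -> Tuple[List[List[float]], List[float]]:
    """Quantize each sub-weight (e.g. q, k, v) with its own per-tensor scale.

    The scalar is repeated once per output channel, so the fused scale vector
    has the same layout a per-channel method would produce.
    """
    fused_q: List[List[float]] = []
    fused_scales: List[float] = []
    for w in sub_weights:
        q, scale = quantize_per_tensor(w)
        fused_q.extend(q)
        fused_scales.extend([scale] * len(w))  # broadcast across this sub-weight's rows
    return fused_q, fused_scales


q_proj = [[1.0, -2.0], [0.5, 4.0]]  # absolute maximum 4.0
k_proj = [[8.0, 0.0]]               # absolute maximum 8.0
fused_q, fused_scales = quantize_fused([q_proj, k_proj])
```

Here fused_scales has three entries: the first two share the scale of q_proj and the third carries the scale of k_proj. A consumer that expects one scale per output channel can therefore read it without knowing that the values came from per-tensor quantization.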
_create_weight is reused from the parent.
Also documents the new value in the --quant_type CLI help.
Verification
- fp8w8a8-pt and vllm-fp8w8a8-pt resolve to the new class in QUANTMETHODS.
- black --line-length=120 and flake8 (project config) pass.
- ... fp8w8a8).
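The alias check in the first verification item can be pictured with a small name-to-class registry. The sketch below is a hypothetical stand-in and does not reproduce the real QUANTMETHODS API: the QuantMethodRegistry class, its register and get methods, and the two placeholder classes are all assumptions made for illustration.

```python
from typing import Callable, Dict, Sequence, Type


class QuantMethodRegistry:
    """Hypothetical registry mapping several --quant_type names to one class."""

    def __init__(self) -> None:
        self._methods: Dict[str, Type] = {}

    def register(self, names: Sequence[str]) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            for name in names:
                if name in self._methods:
                    raise ValueError(f"quant_type {name!r} is already registered")
                self._methods[name] = cls
            return cls

        return decorator

    def get(self, name: str) -> Type:
        try:
            return self._methods[name]
        except KeyError:
            raise ValueError(f"unknown quant_type: {name!r}") from None


registry = QuantMethodRegistry()  # stand-in for QUANTMETHODS


@registry.register(["fp8w8a8"])
class PerChannelMethod:
    """Placeholder for the existing per-channel FP8 W8A8 method."""


@registry.register(["fp8w8a8-pt", "vllm-fp8w8a8-pt"])
class PerTensorMethod(PerChannelMethod):
    """Placeholder for the new per-tensor subclass."""


# Both spellings should resolve to the same class object.
resolved_short = registry.get("fp8w8a8-pt")
resolved_alias = registry.get("vllm-fp8w8a8-pt")
```

A check of this kind only confirms that name resolution works; it says nothing about numerical correctness of the quantized outputs.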