diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..d91d8c2851 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -114,7 +114,6 @@ def moe_align1_kernel( TOKEN_BLOCK_SIZE: tl.constexpr, NUM_STAGE: tl.constexpr, ): - expert_id = tl.program_id(axis=0) off_n = tl.arange(0, TOKEN_BLOCK_SIZE) @@ -308,7 +307,6 @@ def moe_align2_kernel( BLOCK_M: tl.constexpr, BLOCK_EXPERT: tl.constexpr, ): - expert_id = tl.program_id(axis=0) off_expert = tl.arange(0, BLOCK_EXPERT) expert_to_token_num = tl.load(experts_token_num_ptr + off_expert, mask=off_expert < expert_num, other=0) @@ -417,6 +415,7 @@ def grouped_matmul_kernel( n_block_num, # int compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, + WEIGHT_SCALE_PER_TENSOR: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, # tile sizes @@ -504,10 +503,13 @@ def grouped_matmul_kernel( a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num)[:, None] a_scale = tl.load(a_scale_ptrs, mask=token_mask[:, None], other=0.0, eviction_policy="evict_last") - b_scale = tl.load( - weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bn[None, :] * weight_scale_stride1, - eviction_policy="evict_last", - ) + if WEIGHT_SCALE_PER_TENSOR: + b_scale = tl.load(weight_scale_ptr + expert_id * weight_scale_stride0, eviction_policy="evict_last") + else: + b_scale = tl.load( + weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bn[None, :] * weight_scale_stride1, + eviction_policy="evict_last", + ) ab_scale = a_scale * b_scale if NEED_TRANS: @@ -702,7 +704,7 @@ def grouped_matmul( expert_to_token_num is tensor shape [expert_num], expert_to_token_index is tensor shape [expert_num, token_num * topk_num], expert_weights is tensor shape [expert_num, out_dim, hidden_dim] - expert_to_weights_scale is tensor shape [expert_num] or + expert_to_weights_scale is tensor shape [expert_num], [expert_num, 1], [expert_num, out_dim] or [expert_num, out_dim // block_size_, hidden_dim // block_size_k], when use_fp8_w8a8 is False, it must be None out is tensor shape [token_num * topk_num, out_dim] @@ -723,6 +725,9 @@ def grouped_matmul( if expert_to_weights_scale.ndim == 3: block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1] block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2] + weight_scale_per_tensor = ( + use_fp8_w8a8 and expert_to_weights_scale is not None and expert_to_weights_scale.numel() == expert_num + ) if run_config is None: if token_inputs.shape[0] <= expert_num: @@ -872,6 +877,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): n_block_num=triton.cdiv(n, BLOCK_SIZE_N), compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + WEIGHT_SCALE_PER_TENSOR=weight_scale_per_tensor, block_size_n=block_size_n, block_size_k=block_size_k, BLOCK_SIZE_M=BLOCK_SIZE_M, diff --git a/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py index cb3a3d7316..925fed2b11 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py @@ -8,7 +8,6 @@ @triton.jit def grouped_launch(pid, m_block_num, n_block_num, group_m: tl.constexpr): - num_pid_in_group = group_m * n_block_num group_id = pid // num_pid_in_group first_pid_m = group_id * group_m @@ -45,6 +44,7 @@ def _scaled_mm_per_token( stride_cn, USE_TMA: tl.constexpr, B_IS_TRANS: tl.constexpr, + B_SCALE_IS_TENSOR: tl.constexpr, NEED_N_MASK: tl.constexpr, NEED_K_MASK: tl.constexpr, BLOCK_M: tl.constexpr, @@ -77,9 +77,11 @@ def _scaled_mm_per_token( b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) Ascale_ptrs = Ascale + offs_am - Bscale_ptrs = Bscale + offs_bn a_s = tl.load(Ascale_ptrs) - b_s = tl.load(Bscale_ptrs) + if B_SCALE_IS_TENSOR: + b_s = tl.load(Bscale) + else: + b_s = tl.load(Bscale + offs_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE) @@ -143,12 +145,13 @@ def get_test_configs(): return fp8_gemm_configs -def _get_static_key(A, B, out_dtype): +def _get_static_key(A, B, Bscale, out_dtype): M, K = A.shape _, N = B.shape return { "N": N, "K": K, + "b_scale_kind": "tensor" if Bscale.numel() == 1 else "channel", "out_dtype": str(out_dtype), } @@ -175,7 +178,8 @@ def scaled_mm_per_token( A: Matrix A with shape of [M, K]. B: Matrix B with shape of [K, N]. Ascale: per-token Quantization scale for A: [M] or [M, 1]. - Bscale: per-channel Quantization scale for B: [N] or [1, N]. + Bscale: per-channel Quantization scale for B: [N] or [1, N], + or per-tensor scale [1]. out_dtype: The data type of out. out: The output matrix with the shape of [M, N]. Returns: @@ -239,6 +243,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): out_desc = None ACC_DTYPE = tl.int32 if A.dtype == torch.int8 else tl.float32 + B_SCALE_IS_TENSOR = Bscale.numel() == 1 _scaled_mm_per_token[grid]( A=A, @@ -260,6 +265,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): stride_cn=out.stride(1), USE_TMA=support_tma, B_IS_TRANS=B_is_trans, + B_SCALE_IS_TENSOR=B_SCALE_IS_TENSOR, NEED_N_MASK=NEED_N_MASK, NEED_K_MASK=NEED_K_MASK, ACC_DTYPE=ACC_DTYPE, diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 65ec6cd145..cee0841dcc 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -21,11 +21,41 @@ "1", ] +FP8_E4M3_MAX = 448.0 + + +def _fp8_per_tensor_quant(weight: torch.Tensor, device_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + weight = weight.float().cuda(device_id) + if weight.ndim == 3: + scale = weight.abs().amax(dim=(-1, -2)) / FP8_E4M3_MAX + else: + scale = weight.abs().max() / FP8_E4M3_MAX + scale = torch.clamp(scale, min=torch.finfo(torch.float32).tiny) + scale_view = scale.reshape(-1, 1, 1) if weight.ndim == 3 else scale + qweight = _fp8_quant_with_scale(weight, scale_view) + return qweight, scale.reshape(-1) + + +def _fp8_quant_with_scale(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return (weight / scale).clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX).to(dtype=torch.float8_e4m3fn) + + +def _copy_scale_with_broadcast(dst: torch.Tensor, src: torch.Tensor) -> None: + if dst.numel() == src.numel(): + dst.copy_(src.reshape_as(dst)) + elif src.numel() == 1: + if dst.dim() == 0: + dst.copy_(src.reshape(())) + else: + dst.copy_(src.reshape(1).expand_as(dst)) + else: + raise ValueError(f"can not copy scale with shape {tuple(src.shape)} to {tuple(dst.shape)}") + class BaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() - assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." + # assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager self.cache_manager = g_cache_manager @@ -124,7 +154,6 @@ def __init__(self): self.has_weight_zero_point = False def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: - qweight, weight_scale = scaled_fp8_quant( weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) @@ -181,6 +210,147 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register( + ["triton-fp8w8a8-pertensor", "fp8w8a8-pertensor", "triton-fp8w8a8-pt", "fp8w8a8-pt"], + platform="cuda", +) +class TritonFP8w8a8PerTensorQuantizationMethod(BaseQuantizationMethod): + def __init__(self): + super().__init__() + self.has_weight_scale = True + self.has_weight_zero_point = False + + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + if weight.ndim == 3 and output.weight_scale is not None and output.weight_scale.numel() == weight.shape[0]: + for expert_idx in range(weight.shape[0]): + qweight, weight_scale = _fp8_per_tensor_quant(weight[expert_idx], self.device_id_) + output.weight[expert_idx].copy_(qweight) + output.weight_scale[expert_idx].copy_(weight_scale.reshape(())) + return + + qweight, weight_scale = _fp8_per_tensor_quant(weight, self.device_id_) + output.weight.copy_(qweight) + _copy_scale_with_broadcast(output.weight_scale, weight_scale) + return + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: + parent_pack = getattr(weight_pack, "_fp8_pt_parent_pack", None) + if parent_pack is None: + super().load_weight(weight, weight_pack) + return + + staged_weight = weight_pack._fp8_pt_staged_weight + staged_weight.copy_(weight.to(device=staged_weight.device, dtype=staged_weight.dtype, non_blocking=True)) + loaded_index = weight_pack._fp8_pt_child_index + if hasattr(weight_pack, "_fp8_pt_expert_index"): + loaded_index = (weight_pack._fp8_pt_expert_index, loaded_index) + parent_pack._fp8_pt_staged_loaded[loaded_index] = True + self._try_finalize_deferred_weight(parent_pack) + return + + def _try_finalize_deferred_weight(self, parent_pack: WeightPack) -> bool: + if getattr(parent_pack, "_fp8_pt_finalized", False): + return True + staged_loaded = parent_pack._fp8_pt_staged_loaded + if isinstance(staged_loaded, torch.Tensor): + all_loaded = bool(staged_loaded.all().item()) + else: + all_loaded = all(staged_loaded) + if not all_loaded: + return False + + self.quantize(parent_pack._fp8_pt_staged_weight, parent_pack) + parent_pack.load_ok = [True, True, True] + parent_pack._fp8_pt_finalized = True + parent_pack._fp8_pt_staged_weight = None + for child_pack in parent_pack._fp8_pt_child_packs: + child_pack.load_ok = [True, True, True] + child_pack._fp8_pt_staged_weight = None + for expert_child_pack in getattr(child_pack, "_fp8_pt_expert_child_packs", []): + expert_child_pack.load_ok = [True, True, True] + expert_child_pack._fp8_pt_staged_weight = None + return True + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + m = input_tensor.shape[0] + n = qweight.shape[1] + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + assert bias is None, "Bias addition is not supported in triton-fp8w8a8-pertensor for now" + return fp8_scaled_mm_per_token( + x_q, + qweight, + x_scale, + weight_scale, + input_tensor.dtype, + out, + ) + + @property + def method_name(self): + return "triton-fp8w8a8-pertensor" + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + if isinstance(out_dims, int): + out_dims = [out_dims] + out_dim = sum(out_dims) + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + + weight_scale = torch.empty(expert_prefix or (1,), dtype=torch.float32, device=f"cuda:{device_id}") + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + weight_splits = torch.split(weight, out_dims, dim=-2) + mm_param_list = [WeightPack(weight=weight, weight_scale=weight_scale) for weight in weight_splits] + + if len(out_dims) > 1: + staged_weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype, device="cpu") + staged_splits = torch.split(staged_weight, out_dims, dim=-2) + mm_param._fp8_pt_staged_weight = staged_weight + if num_experts > 1: + mm_param._fp8_pt_staged_loaded = torch.zeros( + (num_experts, len(mm_param_list)), dtype=torch.bool, device="cpu" + ) + else: + mm_param._fp8_pt_staged_loaded = [False] * len(mm_param_list) + mm_param._fp8_pt_child_packs = mm_param_list + mm_param._fp8_pt_finalized = False + for idx, (child_pack, staged_split) in enumerate(zip(mm_param_list, staged_splits)): + child_pack._fp8_pt_parent_pack = mm_param + child_pack._fp8_pt_child_index = idx + child_pack._fp8_pt_staged_weight = staged_split + if num_experts > 1: + child_pack._fp8_pt_expert_child_packs = [] + child_pack._fp8_pt_get_expert = child_pack.get_expert + + def _get_deferred_expert(expert_idx, _child_pack=child_pack): + expert_child_pack = _child_pack._fp8_pt_get_expert(expert_idx) + expert_child_pack._fp8_pt_parent_pack = _child_pack._fp8_pt_parent_pack + expert_child_pack._fp8_pt_child_index = _child_pack._fp8_pt_child_index + expert_child_pack._fp8_pt_expert_index = expert_idx + expert_child_pack._fp8_pt_staged_weight = _child_pack._fp8_pt_staged_weight[expert_idx] + _child_pack._fp8_pt_expert_child_packs.append(expert_child_pack) + return expert_child_pack + + child_pack.get_expert = _get_deferred_expert + return mm_param, mm_param_list + + @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"], platform="cuda") class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): def __init__(self): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..3e8da61ab3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -603,7 +603,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--quant_type", type=str, default="none", - help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 + help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | triton-fp8w8a8-pertensor | vllm-fp8w8a8-b128 | deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin | | triton-fp8w8a8g128 (weight perchannel quant and act per group quant) | triton-fp8w8a8g64 (weight perchannel quantization with group size 64)""",