From 0c310a53ada0335a40d43713394a3f9030ff2171 Mon Sep 17 00:00:00 2001 From: zccjjj Date: Mon, 25 May 2026 21:52:06 +0800 Subject: [PATCH] [XPU] support yiyan model w4a8C8/C16+TP4EP4/PD disaggregation+skip layer mix quant --- .../layers/backends/xpu/attention.py | 8 +- .../layers/backends/xpu/moe/fused_moe.py | 31 ++++ .../backends/xpu/quantization/kv_cache.py | 132 ++++++++++-------- fastdeploy/model_executor/layers/moe/moe.py | 18 +++ .../layers/quantization/__init__.py | 9 ++ .../model_executor/models/ernie4_5_moe.py | 1 + fastdeploy/model_executor/utils.py | 19 ++- 7 files changed, 154 insertions(+), 64 deletions(-) diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index dbc41d6dd19..87296fc59eb 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -179,8 +179,8 @@ def forward_mixed( cache_v_scale = getattr(layer, "cache_v_scale", None) cache_k_out_scale = getattr(layer, "cache_k_out_scale", None) cache_v_out_scale = getattr(layer, "cache_v_out_scale", None) - cache_k_zp = getattr(self, "cache_k_zp", None) - cache_v_zp = getattr(self, "cache_v_zp", None) + cache_k_zp = getattr(layer, "cache_k_zp", None) + cache_v_zp = getattr(layer, "cache_v_zp", None) if layer.use_qk_norm: q_norm_weight = layer.q_norm_weight @@ -220,8 +220,8 @@ def forward_mixed( cache_v_scale, cache_k_out_scale, cache_v_out_scale, - cache_k_zp, - cache_v_zp, + cache_k_zp.astype("bfloat16") if cache_k_zp is not None else None, # for C8 + cache_v_zp.astype("bfloat16") if cache_v_zp is not None else None, # for C8 None, # shift None, # smooth q_norm_weight, diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index f2c37452ca4..cf365610eba 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ 
b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -277,6 +277,26 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): default_initializer=paddle.nn.initializer.Constant(0), ), ) + # Set weight loader for up_gate_proj_weight and down_proj_weight and their scales in W4A8 case. + for weight_scale_name in self.added_scale_attrs: + set_weight_attrs( + getattr(layer, weight_scale_name), + { + "weight_loader": extra_weight_attrs.get( + "weight_loader", default_weight_loader(layer.fd_config) + ), + }, + ) + + for weight_name in self.added_weight_attrs: + set_weight_attrs( + getattr(layer, weight_name), + { + "weight_loader": extra_weight_attrs.get( + "weight_loader", default_weight_loader(layer.fd_config) + ), + }, + ) if self.moe_quant_type in ["w8a8", "w4a8"]: for in_scale_name in self.added_in_scale_attrs: @@ -289,6 +309,17 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): default_initializer=paddle.nn.initializer.Constant(0), ), ) + # Set weight_loader for offline in_scale + for in_scale_name in self.added_in_scale_attrs: + set_weight_attrs( + getattr(layer, in_scale_name), + { + "SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None}, + "weight_loader": extra_weight_attrs.get( + "weight_loader", default_weight_loader(layer.fd_config) + ), + }, + ) def process_loaded_weights(self, layer: nn.Layer, state_dict): up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/kv_cache.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/kv_cache.py index 25044bc0939..6d96e2da374 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/quantization/kv_cache.py +++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/kv_cache.py @@ -30,6 +30,23 @@ from fastdeploy.model_executor.utils import set_weight_attrs +def _tp_shard_along_kv_heads( + loaded_weight: paddle.Tensor, + total_kv_heads: int, + 
tp_size: int, + tp_rank: int, + is_pre_sharded: bool, +) -> paddle.Tensor: + """Slice ``loaded_weight`` along the kv_heads dim for the current TP rank.""" + if tp_size <= 1 or is_pre_sharded: + return loaded_weight + assert total_kv_heads % tp_size == 0, f"num_kv_heads ({total_kv_heads}) must be divisible by tp_size ({tp_size})" + head_dim = loaded_weight.numel() // total_kv_heads + kv_heads_per_rank = total_kv_heads // tp_size + start = tp_rank * kv_heads_per_rank + return loaded_weight.reshape([total_kv_heads, head_dim])[start : start + kv_heads_per_rank, :] + + class XPUKvCacheQuantConfig(QuantConfigBase): """ quantization config for weight 4bits and activation fp8 @@ -42,6 +59,7 @@ def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_poi super().__init__() self.kv_cache_quant_type = kv_cache_quant_type self.is_channel_wise = is_channel_wise + self.has_zero_point = has_zero_point try: self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type) @@ -140,64 +158,62 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): if self.cache_quant_config.is_channel_wise: scale_shape = [layer.kv_num_heads * layer.head_dim] - layer.cache_k_scale = layer.create_parameter( - shape=scale_shape, - dtype=paddle.get_default_dtype(), - default_initializer=paddle.nn.initializer.Constant(0), - ) - layer.cache_v_scale = layer.create_parameter( - shape=scale_shape, - dtype=paddle.get_default_dtype(), - default_initializer=paddle.nn.initializer.Constant(0), - ) - - set_weight_attrs( - layer.cache_k_scale, - { - **extra_weight_attrs, - }, - ) - set_weight_attrs( - layer.cache_v_scale, - { - **extra_weight_attrs, - }, - ) - - layer.cache_k_out_scale = layer.create_parameter( - shape=scale_shape, - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ) - layer.cache_v_out_scale = layer.create_parameter( - shape=scale_shape, - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ) + # Build per-channel 
weight loaders (and the attrs that carry them) + # before creating params, so each param can be attached at creation site. + if self.cache_quant_config.is_channel_wise: + # C8 weight loader + fd_config = layer.fd_config + total_kv_heads = fd_config.model_config.num_key_value_heads + tp_size = fd_config.parallel_config.tensor_parallel_size + tp_rank = fd_config.parallel_config.tensor_parallel_rank + max_bound = self.cache_quant_config.max_bound + + def _shard(t): + return _tp_shard_along_kv_heads( + t, total_kv_heads, tp_size, tp_rank, fd_config.load_config.is_pre_sharded + ) + + def _kv_scale_weight_loader(param, loaded_weight, shard_id=None): + loaded_weight = _shard(get_tensor(loaded_weight).cast("float32")) + loaded_weight = paddle.clip(loaded_weight, min=1e-8) + param.copy_((max_bound / loaded_weight).reshape(param.shape).cast(param.dtype), False) + + def _kv_zp_weight_loader(param, loaded_weight, shard_id=None): + loaded_weight = _shard(get_tensor(loaded_weight).cast(param.dtype)) + param.copy_(loaded_weight.reshape(param.shape), False) + + scale_weight_attrs = {**extra_weight_attrs, "weight_loader": _kv_scale_weight_loader} + zp_weight_attrs = {**extra_weight_attrs, "weight_loader": _kv_zp_weight_loader} + else: + scale_weight_attrs = extra_weight_attrs + zp_weight_attrs = extra_weight_attrs - if self.cache_quant_config.has_zero_point: - layer.cache_k_zp = layer.create_parameter( - shape=scale_shape, - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ) - layer.cache_v_zp = layer.create_parameter( + default_dtype = paddle.get_default_dtype() + + def _make_param(name, dtype, attrs=None): + # Note: shared loader auto-dispatches to k or v via the param name + # (loaders inspect cache_k_*/cache_v_* keys downstream). 
+ param = layer.create_parameter( shape=scale_shape, - dtype="float32", + dtype=dtype, default_initializer=paddle.nn.initializer.Constant(0), ) - set_weight_attrs( - layer.cache_k_zp, - { - **extra_weight_attrs, - }, - ) - set_weight_attrs( - layer.cache_v_zp, - { - **extra_weight_attrs, - }, - ) + setattr(layer, name, param) + if attrs is not None: + set_weight_attrs(param, attrs) + return param + + # Quantization scales (write path): bf16 for the W4A8 attention kernel + _make_param("cache_k_scale", default_dtype, scale_weight_attrs) + _make_param("cache_v_scale", default_dtype, scale_weight_attrs) + + # Inverse-quantization scales (read path): fp32 for accuracy + _make_param("cache_k_out_scale", "float32") + _make_param("cache_v_out_scale", "float32") + + if self.cache_quant_config.has_zero_point: + _make_param("cache_k_zp", "float32", zp_weight_attrs) + _make_param("cache_v_zp", "float32", zp_weight_attrs) def process_loaded_weights(self, layer: nn.Layer, state_dict): """ @@ -220,9 +236,13 @@ def process_weights_after_loading(self, layer: nn.Layer): """ # cache_k_out_scale is the reciprocal of cache_k_scale if layer.cache_k_scale._is_initialized(): - layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale + layer.cache_k_out_scale.set_value( + self.cache_quant_config.max_bound / layer.cache_k_scale.cast("float32").reshape_([-1]) + ) if layer.cache_v_scale._is_initialized(): - layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale) + layer.cache_v_out_scale.set_value( + self.cache_quant_config.max_bound / layer.cache_v_scale.cast("float32").reshape_([-1]) + ) def apply(self, layer): """ diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index a194197fb2c..2927b6a92e9 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -307,6 +307,19 @@ def __init__( tp_size={self.tp_size}." 
) + def _load_in_scale_weight(self, param, expert_id, loaded_weight): + # only support ernie now + expert_param = param[expert_id - self.expert_id_offset] + loaded_weight = get_tensor(loaded_weight) + if len(expert_param.shape) != len(loaded_weight.shape): + loaded_weight = loaded_weight.reshape(expert_param.shape) + assert expert_param.shape == loaded_weight.shape, ( + f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})" + ) + if expert_param.dtype != loaded_weight.dtype: + loaded_weight = loaded_weight.cast(expert_param.dtype) + param[expert_id - self.expert_id_offset].copy_(loaded_weight, False) + def weight_loader( self, param, @@ -339,6 +352,11 @@ def weight_loader( if weight_need_transpose: loaded_weight = loaded_weight.transpose([1, 0]) + if SHARD_ID_TO_SHARDED_DIM["gate"] is None and SHARD_ID_TO_SHARDED_DIM["up"] is None: + # in scale + self._load_in_scale_weight(param, expert_id, loaded_weight) + return + if shard_id is None: # 1.gate up fused in disk output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]] diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index 5780edc1d1d..24a29e81383 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -276,4 +276,13 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]: if quantization == "modelopt_fp4": method_to_config["modelopt_fp4"] = ModelOptNvFp4Config + from fastdeploy.platforms import current_platform + + # For XPU platform, use XPUKvCacheQuantConfig instead of KvCacheQuantConfig + if quantization == "kvcache" and current_platform.is_xpu(): + from fastdeploy.model_executor.layers.backends.xpu.quantization.kv_cache import ( + XPUKvCacheQuantConfig, + ) + + method_to_config["kvcache"] = XPUKvCacheQuantConfig return method_to_config[quantization] diff
--git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 3fb32b82ecb..4d71a96a8ec 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -597,6 +597,7 @@ def load_weights(self, weights_iterator) -> None: ("attn.cache_k_scale", "cachek_matmul.in_scale", None, None), ("attn.cache_v_scale", "cachev_matmul.in_scale", None, None), ("up_gate_proj_in_scale", "up_gate_proj.in_scale", None, None), + ("down_proj_in_scale", "down_proj.in_scale", None, None), ] expert_params_mapping = [] diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 3c8400986f5..d07e232f132 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -16,6 +16,7 @@ import importlib import importlib.util +import math import os import re from collections.abc import Mapping @@ -357,8 +358,8 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation loaded_weight = fd_cast(loaded_weight, param) - if param.shape != loaded_weight.shape: - # for e_score_correction_bias + if param.shape != loaded_weight.shape and math.prod(param.shape) == math.prod(loaded_weight.shape): + # for e_score_correction_bias and kv cache scale loaded_weight = loaded_weight.reshape(param.shape) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" @@ -439,7 +440,7 @@ def _get_unsupported_quant(): if current_platform.is_cuda(): return {"w4a8", "wint2"} elif current_platform.is_xpu(): - return {"w4a8", "w8a8"} + return {"w8a8"} return set() def _err_msg(msg: str) -> str: @@ -543,6 +544,10 @@ def rename_offline_ckpt_suffix_to_fd_suffix( ckpt_weight_suffix: "weight", ckpt_act_suffix: "in_scale", } + w4a8_suffix_map = { + ckpt_weight_suffix: "weight", + ckpt_act_suffix: 
"in_scale", + } moe_quant_type = "" dense_quant_type = "" if fd_config.quant_config is not None: @@ -560,8 +565,14 @@ def fn(loaded_weight_name, is_moe): fd_suffix_map = {} if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"): fd_suffix_map = fp8_suffix_map - if (is_moe and moe_quant_type == "tensor_wise_fp8") or (not is_moe and dense_quant_type == "tensor_wise_fp8"): + elif (is_moe and moe_quant_type == "tensor_wise_fp8") or ( + not is_moe and dense_quant_type == "tensor_wise_fp8" + ): fd_suffix_map = tensor_wise_fp8_suffix_map + elif is_moe and moe_quant_type in ("w4a8", "w4afp8"): + fd_suffix_map = w4a8_suffix_map + else: + fd_suffix_map = {} for ckpt_suffix, fd_suffix in fd_suffix_map.items(): if re.search(rf"{ckpt_suffix}$", loaded_weight_name): loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix)