diff --git a/docs/CN/source/models/add_new_model.md b/docs/CN/source/models/add_new_model.md index 49b47ffa26..5d34cf747e 100755 --- a/docs/CN/source/models/add_new_model.md +++ b/docs/CN/source/models/add_new_model.md @@ -162,19 +162,6 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): self.tp_rank_: split_vob_size * (self.tp_rank_ + 1), :]) self.lm_head_weight_ = self.wte_weight_ return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.pre_norm_weight_, - self.pre_norm_bias_, - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.lm_head_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return - ~~~ ***transformer_layer_weight.py*** @@ -204,30 +191,6 @@ class BloomTransformerLayerWeight(TransformerLayerWeight): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.att_norm_bias_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.o_bias_, - - self.ffn_norm_weight_, - self.ffn_norm_bias_, - self.ffn_1_weight_, - self.ffn_1_bias_, - self.ffn_2_weight_, - self.ffn_2_bias_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return def _load_qkvo_weights(self, weights): if f"h.{self.layer_num_}.input_layernorm.weight" in weights: diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index b7f6312a60..e86929a89f 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -367,17 +367,14 @@ PD 分离模式参数 .. option:: --quant_type 量化方法,可选值: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``vllm-fp8w8a8-b128`` + * ``deepgemm-fp8w8a8-b128`` * ``triton-fp8w8a8-block128`` + * ``awq`` + * ``awq_marlin`` * ``none`` (默认) .. option:: --quant_cfg @@ -389,13 +386,7 @@ PD 分离模式参数 .. option:: --vit_quant_type ViT 量化方法,可选值: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``none`` (默认) diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 2fc5d3e621..de7ecc84c3 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -49,13 +49,14 @@ LightLLM 支持以下几种部署模式: .. 
code-block:: bash # H200 单机 DeepSeek-R1 DP + EP 模式 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --dp 8 + --dp 8 \ + --enable_ep_moe **参数说明:** -- `MOE_MODE=EP`: 设置专家并行模式 +- `--enable_ep_moe`: 设置专家并行模式 - `--tp 8`: 张量并行度 - `--dp 8`: 数据并行度,通常设置为与 tp 相同的值 @@ -119,14 +120,14 @@ LightLLM 支持以下几种部署模式: # H200 多机 DeepSeek-R1 EP 模式 Node 0 # 使用方法: sh multi_node_ep_node0.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **Node 1 启动命令:** @@ -135,14 +136,14 @@ LightLLM 支持以下几种部署模式: # H200 多机 DeepSeek-R1 EP 模式 Node 1 # 使用方法: sh multi_node_ep_node1.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **可选优化参数:** - `--enable_prefill_microbatch_overlap`: 启用预填充微批次重叠 @@ -179,7 +180,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -189,7 +190,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --nccl_port 2732 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ - --pd_master_port 60011 + --pd_master_port 60011 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_prefill_microbatch_overlap @@ -202,7 +204,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ @@ -212,7 +214,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --nccl_port 12322 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ - --pd_master_port 60011 + --pd_master_port 60011 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_decode_microbatch_overlap @@ -269,7 +272,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --host $host \ @@ -279,7 +282,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --nccl_port 2732 \ --disable_cudagraph \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_prefill_microbatch_overlap @@ -287,7 +291,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --host $host \ @@ -296,7 +300,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --tp 8 \ --dp 8 \ --config_server_host $config_server_host 
\ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_decode_microbatch_overlap diff --git a/docs/EN/source/models/add_new_model.md b/docs/EN/source/models/add_new_model.md index 6127dffaf7..7417c39cfe 100755 --- a/docs/EN/source/models/add_new_model.md +++ b/docs/EN/source/models/add_new_model.md @@ -162,18 +162,6 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): self.tp_rank_: split_vob_size * (self.tp_rank_ + 1), :]) self.lm_head_weight_ = self.wte_weight_ return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.pre_norm_weight_, - self.pre_norm_bias_, - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.lm_head_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return ~~~ @@ -204,30 +192,6 @@ class BloomTransformerLayerWeight(TransformerLayerWeight): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.att_norm_bias_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.o_bias_, - - self.ffn_norm_weight_, - self.ffn_norm_bias_, - self.ffn_1_weight_, - self.ffn_1_bias_, - self.ffn_2_weight_, - self.ffn_2_bias_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return def _load_qkvo_weights(self, weights): if f"h.{self.layer_num_}.input_layernorm.weight" in weights: diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index 18fe54c552..73cf12513a 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -359,17 +359,14 @@ Quantization Parameters .. option:: --quant_type Quantization method, optional values: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``vllm-fp8w8a8-b128`` + * ``deepgemm-fp8w8a8-b128`` * ``triton-fp8w8a8-block128`` + * ``awq`` + * ``awq_marlin`` * ``none`` (default) .. option:: --quant_cfg @@ -381,13 +378,7 @@ Quantization Parameters .. option:: --vit_quant_type ViT quantization method, optional values: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``none`` (default) diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 280a61ceb2..4c5a121dd6 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -49,13 +49,14 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3. .. 
code-block:: bash # H200 Single node DeepSeek-R1 DP + EP Mode - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --dp 8 + --dp 8 \ + --enable_ep_moe **Parameter Description:** -- `MOE_MODE=EP`: Set expert parallelism mode +- `--enable_ep_moe`: Set expert parallelism mode - `--tp 8`: Tensor parallelism - `--dp 8`: Data parallelism, usually set to the same value as tp @@ -119,14 +120,14 @@ Suitable for deploying MoE models across multiple nodes. # H200 Multi-node DeepSeek-R1 EP Mode Node 0 # Usage: sh multi_node_ep_node0.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **Node 1 Launch Command:** @@ -135,14 +136,14 @@ Suitable for deploying MoE models across multiple nodes. # H200 Multi-node DeepSeek-R1 EP Mode Node 1 # Usage: sh multi_node_ep_node1.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **Optional Optimization Parameters:** - `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap @@ -179,7 +180,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -188,7 +189,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --port 8019 \ --nccl_port 2732 \ --disable_cudagraph \ - --pd_master_ip $pd_master_ip + --pd_master_ip $pd_master_ip \ + --enable_ep_moe **Step 3: Launch Decode Service** @@ -199,7 +201,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ @@ -209,7 +211,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --nccl_port 12322 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ - --pd_master_port 60011 + --pd_master_port 60011 \ + --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_decode_microbatch_overlap @@ -266,7 +269,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --host $host \ @@ -276,7 +279,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --nccl_port 2732 \ --disable_cudagraph \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 
60088 \ + --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap @@ -284,7 +288,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --host $host \ @@ -293,7 +297,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --tp 8 \ --dp 8 \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_decode_microbatch_overlap diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index 3feed1ef46..12b2b0dfa8 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -4,7 +4,7 @@ from typing import Optional, TYPE_CHECKING from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant +from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 26d51af3db..e6405e4d74 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -103,20 +103,15 @@ def __init__(self, kvargs): self._verify_params() self._init_quant() - # 更连续的显存分配可以有更好的性能 - if self.max_total_token_num is None: - self._init_weights() - self._init_mem_manager() - else: - self._init_mem_manager() - self._init_weights() - + self._init_weights() + self._init_mem_manager() self._init_kv_move_buffer() self._check_mem_size() self._init_req_manager() self._init_infer_layer() self._init_some_value() self._init_custom() + self._load_hf_weights() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -179,6 +174,9 @@ def _init_weights(self, start_layer_index=0): ) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] + return + + def _load_hf_weights(self): load_hf_weights( self.data_type, weight_dir=self.weight_dir_, @@ -639,8 +637,6 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod assert model_input0.mem_indexes.is_cuda assert model_input1.mem_indexes.is_cuda - input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids - infer_state0 = self._create_inferstate(model_input0, 0) init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, @@ -670,9 +666,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() - model_output0, model_output1 = self._overlap_tpsp_context_forward( - input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1 - ) + model_output0, model_output1 = self._overlap_tpsp_context_forward(infer_state0, infer_state1=infer_state1) # 
在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 0fa02780cb..8e884012d5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -1,13 +1,12 @@ from .base_weight import BaseWeight from .mm_weight import ( - MMWeightPack, MMWeightTpl, ROWMMWeight, - COLMMWeight, + KVROWNMMWeight, ROWBMMWeight, + COLMMWeight, ) -from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight -from .fused_moe_weight_tp import create_tp_moe_wegiht_obj -from .fused_moe_weight_ep import FusedMoeWeightEP +from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight +from .fused_moe.fused_moe_weight import FusedMoeWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index 3f8e1f50ab..2013d55be0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -1,23 +1,57 @@ import torch -from typing import Dict +from typing import Dict, Tuple from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id class TpAttSinkWeight(BaseWeightTpl): - def __init__(self, weight_name: str, data_type): + def __init__(self, all_q_head_num: int, weight_name: str, data_type): super().__init__() + self.all_q_head_num = all_q_head_num self.weight_name = weight_name self.data_type_ = data_type - self.weight: torch.Tensor = None + self._start_head_index, self._end_head_index = self._get_head_tp_split_params(all_head_num=self.all_q_head_num) + self._create_weight() + + def _create_weight(self): + self.weight = torch.empty( + (self._end_head_index - self._start_head_index,), dtype=self.data_type_, device="cuda" + ) + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name not in weights or self.weight is not None: + if self.weight_name not in weights: return t_weight = weights[self.weight_name] - start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) - self.weight = t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) + self.weight = ( + t_weight[self._start_head_index : self._end_head_index].to(self.data_type_).cuda(get_current_device_id()) + ) + self.weight.load_ok = True def verify_load(self): - return self.weight is not None + return self.weight.load_ok + + def _get_head_tp_split_params(self, all_head_num: int) -> Tuple[int, int]: + """ + Docstring for _get_head_tp_split_params, + 一个常用的tp 划分head获取head_index 范围的功能函数, 一些继承类可能会使用。 + :param self: Description + :param weight: Description + :type weight: torch.Tensor + :return: Description + :rtype: Tuple[int, int] + """ + tp_head_num = all_head_num // self.tp_world_size_ + + if tp_head_num > 0: + start_head_index = self.tp_rank_ * tp_head_num + end_head_index = (self.tp_rank_ + 1) * tp_head_num + else: + # 当 tp_world_size 大于 all_head_num 时的特殊处理 + scale_size = self.tp_world_size_ // all_head_num + assert self.tp_world_size_ % all_head_num == 0 + 
start_head_index = self.tp_rank_ // scale_size + end_head_index = start_head_index + 1 + + return start_head_index, end_head_index diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 2cd8ea6ae0..714e7acf48 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -6,12 +6,17 @@ class BaseWeight(ABC): def __init__(self): + super().__init__() pass @abstractmethod def load_hf_weights(self, weights): pass + @abstractmethod + def _create_weight(self): + pass + @abstractmethod def verify_load(self) -> bool: pass @@ -19,6 +24,7 @@ def verify_load(self) -> bool: class BaseWeightTpl(BaseWeight): def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: torch.dtype = None): + super().__init__() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.device_id_ = get_current_device_id() @@ -30,29 +36,5 @@ def load_hf_weights(self, weights): def verify_load(self) -> bool: raise NotImplementedError("verify_load must implement this method") - def _get_head_tp_split_params(self, weight: torch.Tensor) -> Tuple[int, int]: - """ - Docstring for _get_head_tp_split_params, - 一个常用的tp 划分head获取head_index 范围的功能函数, 一些继承类可能会使用。 - :param self: Description - :param weight: Description - :type weight: torch.Tensor - :return: Description - :rtype: Tuple[int, int] - """ - assert weight.ndim == 2 - - all_head_num = weight.shape[0] - tp_head_num = all_head_num // self.tp_world_size_ - - if tp_head_num > 0: - start_head_index = self.tp_rank_ * tp_head_num - end_head_index = (self.tp_rank_ + 1) * tp_head_num - else: - # 当 tp_world_size 大于 all_head_num 时的特殊处理 - scale_size = self.tp_world_size_ // all_head_num - assert self.tp_world_size_ % all_head_num == 0 - start_head_index = self.tp_rank_ // scale_size - end_head_index = start_head_index + 1 - - return start_head_index, end_head_index + def _create_weight(self): + raise NotImplementedError("create_weight must implement this method") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index fc018267fa..d94a4c709b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -2,46 +2,62 @@ import numpy as np from typing import Dict, Optional from .base_weight import BaseWeightTpl -from lightllm.utils.dist_utils import get_current_device_id +from .platform_op import PlatformAwareOp from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel -from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp -logger = init_logger(__name__) - -class EmbeddingWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type): +class EmbeddingWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): super().__init__() + self.dim = dim + self.vocab_size = vocab_size + # 计算 split_indexes + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) + self.tp_vocab_end_id = 
int(split_indexes[self.tp_rank_ + 1]) self.weight_name: str = weight_name self.data_type_ = data_type - self.weight: torch.Tensor = None + self._create_weight() + + def _create_weight(self): + tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id + self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name not in weights or self.weight is not None: + if self.weight_name not in weights: return - t_weight = weights[self.weight_name] # init some params - self.vocab_size = len(t_weight) - split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) - self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) - - logger.info(f"loaded weight vocab_size: {self.vocab_size}") - - self.weight = ( - t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_).cuda(get_current_device_id()) - ) + loaded_vocab_size = len(t_weight) + assert ( + loaded_vocab_size == self.vocab_size + ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" + self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + self.weight.load_ok = True def verify_load(self): - return self.weight is not None - - def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + return self.weight.load_ok + + def _native_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + adjusted_ids = input_ids - self.tp_vocab_start_id + adjusted_ids = torch.clamp(adjusted_ids, 0, self.weight.shape[0] - 1) + result = torch.nn.functional.embedding(adjusted_ids, self.weight) + if out is not None: + out.copy_(result) + return out + return result + + def _triton_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: if out is None: out = alloc_func( (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device ) - embedding_kernel( input_ids=input_ids, weight=self.weight, @@ -49,10 +65,73 @@ def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, vob_end_id=self.tp_vocab_end_id, out=out, ) - return out - def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + def _cuda_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._triton_forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + + def _musa_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # triton implementation is supported by musa. 
+ return self._triton_forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + + def __call__( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + + +class LMHeadWeight(EmbeddingWeight): + def __init__( + self, + dim: int, + vocab_size: int, + weight_name: str, + data_type: torch.dtype, + embedding_weight: Optional[EmbeddingWeight] = None, + ): + self._embedding_weight = embedding_weight + super().__init__(dim=dim, vocab_size=vocab_size, weight_name=weight_name, data_type=data_type) + + def _create_weight(self): + if self._embedding_weight is not None: + self.weight = self._embedding_weight.weight + return + super()._create_weight() + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + # When set tile_embedding=True, no need to load - EmbeddingWeight already loaded it + if self._embedding_weight is not None: + return + if self.weight_name not in weights: + return + t_weight = weights[self.weight_name] + loaded_vocab_size = len(t_weight) + assert ( + loaded_vocab_size == self.vocab_size + ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" + self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + self.weight.load_ok = True + + def verify_load(self): + return self.weight.load_ok + + def _native_forward( + self, input: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 + result = torch.mm(self.weight, input) + if out is not None: + out.copy_(result) + return out + return result + + def _cuda_forward( + self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: assert input.ndim == 2 if out is None: out = alloc_func( @@ -60,49 +139,71 @@ def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc dtype=input.dtype, device=input.device, ) - torch.mm(self.weight, input, out=out) return out - -class LMHeadWeight(EmbeddingWeight): - def __init__(self, weight_name, data_type): - super().__init__(weight_name, data_type) + def __call__(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty) -> torch.Tensor: + return self._forward(input=input, out=out, alloc_func=alloc_func) -class NoTpPosEmbeddingWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type): +class NoTpPosEmbeddingWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, data_type: torch.dtype): super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings self.weight_name: str = weight_name self.data_type_ = data_type - self.weight: torch.Tensor = None self.tp_world_size_ = 1 self.tp_rank_ = 0 + self._create_weight() + + def _create_weight(self): + self.weight: torch.Tensor = torch.empty( + self.max_position_embeddings, self.dim, dtype=self.data_type_, device=self.device_id_ + ) + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name not in weights or self.weight is not None: + if self.weight_name not in weights: return - t_weight = weights[self.weight_name] - self.weight = t_weight.to(self.data_type_).cuda(get_current_device_id()) - self.end_position_id: int = t_weight.shape[0] - logger.info(f"loaded weight end_position_id: {self.end_position_id}") + 
loaded_max_position_embeddings = t_weight.shape[0] + assert ( + loaded_max_position_embeddings == self.max_position_embeddings + ), f"max_position_embeddings: {loaded_max_position_embeddings} != expected: {self.max_position_embeddings}" + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True def verify_load(self): - return self.weight is not None - - def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + return self.weight.load_ok + + def _native_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + # Use PyTorch native embedding + result = torch.nn.functional.embedding(input_ids, self.weight) + if out is not None: + out.copy_(result) + return out + return result + + def _cuda_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: if out is None: out = alloc_func( (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device ) - embedding_kernel( input_ids=input_ids, weight=self.weight, vob_start_id=0, - vob_end_id=self.end_position_id, + vob_end_id=self.max_position_embeddings, out=out, ) - return out + + def __call__( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input_ids=input_ids, out=out, alloc_func=alloc_func) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py similarity index 68% rename from lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py index b53200d4c8..749400c8d8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py @@ -1,6 +1,6 @@ import numpy as np import torch -from .fused_moe_weight_ep import FusedMoeWeightEP +from .fused_moe_weight import FusedMoeWeight from lightllm.utils.log_utils import init_logger from typing import Dict @@ -10,7 +10,7 @@ class FusedMoeWeightEPAutoRedundancy: def __init__( self, - ep_fused_moe_weight: FusedMoeWeightEP, + ep_fused_moe_weight: FusedMoeWeight, ) -> None: super().__init__() self._ep_w = ep_fused_moe_weight @@ -25,12 +25,13 @@ def prepare_redundancy_experts( ): expert_counter = self._ep_w.routed_expert_counter_tensor.detach().cpu().numpy() logger.info( - f"layer_index {self._ep_w.layer_num} global_rank {self._ep_w.global_rank_} expert_counter: {expert_counter}" + f"layer_index {self._ep_w.layer_num_} global_rank {self._ep_w.global_rank_}" + f" expert_counter: {expert_counter}" ) self._ep_w.routed_expert_counter_tensor.fill_(0) - - start_expert_id = self._ep_w.ep_n_routed_experts * self._ep_w.global_rank_ - no_redundancy_expert_ids = list(range(start_expert_id, start_expert_id + self._ep_w.ep_n_routed_experts)) + ep_n_routed_experts = self._ep_w.n_routed_experts // self._ep_w.global_world_size + start_expert_id = ep_n_routed_experts * self._ep_w.global_rank_ + no_redundancy_expert_ids = list(range(start_expert_id, start_expert_id + ep_n_routed_experts)) # 统计 0 rank 上的全局 topk 冗余信息,帮助导出一份全局可用的静态使用的冗余专家静态配置。 if self._ep_w.global_rank_ == 0: @@ -44,7 +45,7 @@ def prepare_redundancy_experts( self.redundancy_expert_ids = 
list(np.argsort(expert_counter)[-self.redundancy_expert_num :]) logger.info( - f"layer_index {self._ep_w.layer_num} global_rank {self._ep_w.global_rank_}" + f"layer_index {self._ep_w.layer_num_} global_rank {self._ep_w.global_rank_}" f" new select redundancy_expert_ids : {self.redundancy_expert_ids}" ) @@ -55,7 +56,7 @@ def prepare_redundancy_experts( self.experts_gate_proj_scales = [None] * self.redundancy_expert_num self.w2_list = [None] * self.redundancy_expert_num self.w2_scale_list = [None] * self.redundancy_expert_num - self.w1 = [None, None] # weight, weight_scale + self.w13 = [None, None] # weight, weight_scale self.w2 = [None, None] # weight, weight_scale return topk_redundancy_expert_ids @@ -73,13 +74,12 @@ def load_hf_weights(self, weights): if w2_weight in weights: self.w2_list[i] = weights[w2_weight] - if self._ep_w.quantized_weight: - self._load_weight_scale(weights) + self._load_weight_scale(weights) self._fuse() def _fuse(self): - if self._ep_w.quantized_weight: - self._fuse_weight_scale() + self._fuse_weight_scale() + with self._ep_w.lock: if ( hasattr(self, "experts_up_projs") @@ -93,23 +93,38 @@ def _fuse(self): dtype = self.experts_gate_projs[0].dtype total_expert_num = self.redundancy_expert_num - w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") + w13 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") for i_experts in range(self.redundancy_expert_num): - w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] - w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] + w13[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] + w13[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - if not self._ep_w.quantized_weight and self._ep_w.quant_method is not None: - qw1, qw1_scale, qw1_zero_point = self._ep_w.quant_method.quantize(w1) - qw2, qw2_scale, qw2_zero_point = self._ep_w.quant_method.quantize(w2) - self.w1[0] = qw1 - self.w1[1] = qw1_scale - self.w2[0] = qw2 - self.w2[1] = qw2_scale + if self._ep_w.quant_method._check_weight_need_quanted(weight=w13): + w13_pack, _ = self._ep_w.quant_method.create_moe_weight( + out_dims=[gate_out_dim + up_out_dim], + in_dim=1, + dtype=self._ep_w.data_type_, + device_id=self._ep_w.device_id_, + num_experts=self.redundancy_expert_num, + ) + self._ep_w.quant_method.quantize(w13, w13_pack) + w2_pack, _ = self._ep_w.quant_method.create_moe_weight( + out_dims=[inter_shape], + in_dim=hidden_size, + dtype=self._ep_w.data_type_, + device_id=self._ep_w.device_id_, + num_experts=self.redundancy_expert_num, + ) + self._ep_w.quant_method.quantize(w2, w2_pack) + + self.w13[0] = w13_pack.weight + self.w13[1] = w13_pack.weight_scale + self.w2[0] = w2_pack.weight + self.w2[1] = w2_pack.weight_scale else: - self.w1[0] = w1 + self.w13[0] = w13 self.w2[0] = w2 delattr(self, "w2_list") delattr(self, "experts_up_projs") @@ -128,18 +143,18 @@ def _fuse_weight_scale(self): assert gate_in_dim == up_in_dim dtype = self.experts_gate_proj_scales[0].dtype total_expert_num = self.redundancy_expert_num - w1_scale = torch.empty( + w13_scale = torch.empty( (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" ) for i_experts in range(self.redundancy_expert_num): - w1_scale[i_experts, 0:gate_out_dim:, 
:] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] + w13_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] + w13_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( len(self.w2_scale_list), inter_shape, hidden_size ) - self.w1[1] = w1_scale + self.w13[1] = w13_scale self.w2[1] = w2_scale delattr(self, "w2_scale_list") delattr(self, "experts_up_proj_scales") @@ -149,15 +164,10 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: # 加载冗余专家的scale参数 for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): i_experts = redundant_expert_id - w1_scale = ( - f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w1_weight_name}.{self._ep_w.weight_scale_suffix}" - ) - w2_scale = ( - f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w2_weight_name}.{self._ep_w.weight_scale_suffix}" - ) - w3_scale = ( - f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w3_weight_name}.{self._ep_w.weight_scale_suffix}" - ) + weight_scale_suffix = self._ep_w.quant_method.weight_scale_suffix + w1_scale = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w1_weight_name}.{weight_scale_suffix}" + w2_scale = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w2_weight_name}.{weight_scale_suffix}" + w3_scale = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w3_weight_name}.{weight_scale_suffix}" if w1_scale in weights: self.experts_gate_proj_scales[i] = weights[w1_scale] if w3_scale in weights: @@ -166,14 +176,14 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: self.w2_scale_list[i] = weights[w2_scale] def commit(self): - for index, dest_tensor in enumerate(self._ep_w.w1): + for index, dest_tensor in enumerate([self._ep_w.w13.weight, self._ep_w.w13.weight_scale]): if dest_tensor is not None: assert isinstance( dest_tensor, torch.Tensor ), f"dest_tensor should be a torch.Tensor, but got {type(dest_tensor)}" - dest_tensor[-self.redundancy_expert_num :, :, :] = self.w1[index][:, :, :] + dest_tensor[-self.redundancy_expert_num :, :, :] = self.w13[index][:, :, :] - for index, dest_tensor in enumerate(self._ep_w.w2): + for index, dest_tensor in enumerate([self._ep_w.w2.weight, self._ep_w.w2.weight_scale]): if dest_tensor is not None: assert isinstance( dest_tensor, torch.Tensor diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py new file mode 100644 index 0000000000..6bcf7fc03c --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -0,0 +1,390 @@ +import torch +import threading +from typing import Dict, Any, Optional, Tuple, List +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( + get_row_slice_mixin, + get_col_slice_mixin, + SliceMixinTpl, +) +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import select_fuse_moe_impl +from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.utils.envs_utils import get_redundancy_expert_ids, 
get_redundancy_expert_num, get_env_start_args +from lightllm.utils.dist_utils import get_global_world_size, get_global_rank +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class FusedMoeWeight(BaseWeightTpl): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + hidden_size: int, + moe_intermediate_size: int, + data_type: torch.dtype, + quant_method: QuantizationMethod = None, + num_fused_shared_experts: int = 0, + layer_num: int = 0, + network_config: Dict[str, Any] = None, + ) -> None: + super().__init__(data_type=data_type) + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + self.e_score_correction_bias_name = e_score_correction_bias_name + self.weight_prefix = weight_prefix + self.layer_num_ = layer_num + self.global_rank_ = get_global_rank() + self.global_world_size = get_global_world_size() + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.quant_method = quant_method + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.enable_ep_moe = get_env_start_args().enable_ep_moe + self.n_routed_experts = n_routed_experts + self.num_fused_shared_experts = num_fused_shared_experts + self._init_config(network_config) + self._init_redundancy_expert_params() + self._init_parallel_params() + self.fuse_moe_impl = select_fuse_moe_impl(self.quant_method, self.enable_ep_moe)( + n_routed_experts=self.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, + routed_scaling_factor=self.routed_scaling_factor, + quant_method=self.quant_method, + redundancy_expert_num=self.redundancy_expert_num, + redundancy_expert_ids_tensor=self.redundancy_expert_ids_tensor, + routed_expert_counter_tensor=self.routed_expert_counter_tensor, + auto_update_redundancy_expert=self.auto_update_redundancy_expert, + ) + self.lock = threading.Lock() + self._create_weight() + + def _init_config(self, network_config: Dict[str, Any]): + self.n_group = network_config.get("n_group", 0) + self.use_grouped_topk = self.n_group > 0 + self.norm_topk_prob = network_config["norm_topk_prob"] + self.topk_group = network_config.get("topk_group", 0) + self.num_experts_per_tok = network_config["num_experts_per_tok"] + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.scoring_func = network_config.get("scoring_func", "softmax") + + def _init_redundancy_expert_params(self): + self.redundancy_expert_num = get_redundancy_expert_num() + self.redundancy_expert_ids = get_redundancy_expert_ids(self.layer_num_) + self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert + self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") + self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda") + # TODO: find out the reason of failure of deepep when redundancy_expert_num is 1. + assert self.redundancy_expert_num != 1, "redundancy_expert_num can not be 1 for some unknown hang of deepep." 
+ + def _init_parallel_params(self): + if self.enable_ep_moe: + self.tp_rank_ = 0 + self.tp_world_size_ = 1 + self.row_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ + ) + self.col_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ + ) + self.local_n_routed_experts = self.n_routed_experts + self.num_fused_shared_experts + self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ + if self.enable_ep_moe: + assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" + logger.info( + f"global_rank {self.global_rank_} layerindex {self.layer_num_} " + f"redundancy_expertids: {self.redundancy_expert_ids}" + ) + self.local_n_routed_experts = self.n_routed_experts // self.global_world_size + self.redundancy_expert_num + n_experts_per_rank = self.n_routed_experts // self.global_world_size + start_expert_id = self.global_rank_ * n_experts_per_rank + self.local_expert_ids = ( + list(range(start_expert_id, start_expert_id + n_experts_per_rank)) + self.redundancy_expert_ids + ) + self.expert_idx_to_local_idx = { + expert_idx: expert_idx - start_expert_id for expert_idx in self.local_expert_ids[:n_experts_per_rank] + } + self.redundancy_expert_idx_to_local_idx = { + redundancy_expert_idx: n_experts_per_rank + i + for (i, redundancy_expert_idx) in enumerate(self.redundancy_expert_ids) + } + else: + self.local_expert_ids = list(range(self.n_routed_experts + self.num_fused_shared_experts)) + self.expert_idx_to_local_idx = {expert_idx: i for (i, expert_idx) in enumerate(self.local_expert_ids)} + self.rexpert_idx_to_local_idx = {} + + def experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ) -> torch.Tensor: + """Backward compatible method that routes to platform-specific implementation.""" + return self.fuse_moe_impl( + input_tensor=input_tensor, + router_logits=router_logits, + w13=self.w13, + w2=self.w2, + correction_bias=self.e_score_correction_bias, + scoring_func=self.scoring_func, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + is_prefill=is_prefill, + ) + + def low_latency_dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ): + assert self.enable_ep_moe, "low_latency_dispatch is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.low_latency_dispatch( + hidden_states=hidden_states, + router_logits=router_logits, + e_score_correction_bias=self.e_score_correction_bias, + use_grouped_topk=self.use_grouped_topk, + num_experts_per_tok=self.num_experts_per_tok, + norm_topk_prob=self.norm_topk_prob, + topk_group=self.topk_group, + n_group=self.n_group, + scoring_func=self.scoring_func, + ) + + def select_experts_and_quant_input( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ): + assert self.enable_ep_moe, "select_experts_and_quant_input is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.select_experts_and_quant_input( + hidden_states=hidden_states, + router_logits=router_logits, + e_score_correction_bias=self.e_score_correction_bias, + w13=self.w13, + use_grouped_topk=self.use_grouped_topk, + num_experts_per_tok=self.num_experts_per_tok, + 
norm_topk_prob=self.norm_topk_prob, + topk_group=self.topk_group, + n_group=self.n_group, + scoring_func=self.scoring_func, + ) + + def dispatch( + self, + qinput_tensor: Tuple[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_event: Optional[Any] = None, + ): + assert self.enable_ep_moe, "dispatch is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.dispatch( + qinput_tensor=qinput_tensor, + topk_idx=topk_idx, + topk_weights=topk_weights, + overlap_event=overlap_event, + ) + + def masked_group_gemm( + self, recv_x: Tuple[torch.Tensor], masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int + ): + assert self.enable_ep_moe, "masked_group_gemm is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.masked_group_gemm( + recv_x=recv_x, + w13=self.w13, + w2=self.w2, + masked_m=masked_m, + dtype=dtype, + expected_m=expected_m, + ) + + def prefilled_group_gemm( + self, + num_recv_tokens_per_expert_list, + recv_x: Tuple[torch.Tensor], + recv_topk_idx: torch.Tensor, + recv_topk_weights: torch.Tensor, + hidden_dtype=torch.bfloat16, + ): + assert self.enable_ep_moe, "prefilled_group_gemm is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.prefilled_group_gemm( + num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, + recv_x=recv_x, + recv_topk_idx=recv_topk_idx, + recv_topk_weights=recv_topk_weights, + w13=self.w13, + w2=self.w2, + hidden_dtype=hidden_dtype, + ) + + def low_latency_combine( + self, + gemm_out_b: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: Any, + ): + assert self.enable_ep_moe, "low_latency_combine is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.low_latency_combine( + gemm_out_b=gemm_out_b, + topk_idx=topk_idx, + topk_weights=topk_weights, + handle=handle, + ) + + def combine( + self, + gemm_out_b: torch.Tensor, + handle: Any, + overlap_event: Optional[Any] = None, + ): + assert self.enable_ep_moe, "combine is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.combine( + gemm_out_b=gemm_out_b, + handle=handle, + overlap_event=overlap_event, + ) + + def load_hf_weights(self, weights): + # Load bias + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + self._load_weight(self.expert_idx_to_local_idx, weights) + if self.redundancy_expert_num > 0: + self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) + + def verify_load(self): + return all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + + def _create_weight(self): + intermediate_size = self.split_inter_size + self.e_score_correction_bias = None + # Create e_score_correction_bias + if self.e_score_correction_bias_name: + self.e_score_correction_bias = torch.empty( + (self.n_routed_experts,), + dtype=self.data_type_, + device=f"cuda:{self.device_id_}", + ) + + self.w13, w13_param_list = self.quant_method.create_moe_weight( + out_dims=[intermediate_size, intermediate_size], + in_dim=self.hidden_size, + dtype=self.data_type_, + device_id=self.device_id_, + num_experts=self.local_n_routed_experts, + ) + self.w2, _ = self.quant_method.create_moe_weight( + out_dims=[self.hidden_size], + in_dim=intermediate_size, + dtype=self.data_type_, + device_id=self.device_id_, + num_experts=self.local_n_routed_experts, + ) + self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) + self.w3_list: 
List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) + self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) + + def _get_expert_weight_list(self, weight_pack: WeightPack): + weight_list = [] + for idx in range(self.local_n_routed_experts): + expert_weight = weight_pack.get_expert(idx) + weight_list.append(expert_weight) + return weight_list + + def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): + + # Load each expert with TP slicing + for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): + with self.lock: + self._load_expert(expert_idx, local_expert_idx, weights) + self._load_expert_scale( + expert_idx, + local_expert_idx, + weights, + ) + self._load_expert_zero_point( + expert_idx, + local_expert_idx, + weights, + ) + + def _load_expert( + self, + expert_idx: int, + local_expert_idx: int, + weights: Dict[str, torch.Tensor], + ): + w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" + w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" + w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = self.col_slicer._slice_weight + if w1_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w1_weight]), self.w1_list[local_expert_idx]) + if w3_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w3_weight]), self.w3_list[local_expert_idx]) + if w2_weight in weights: + self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + + def _load_expert_scale( + self, + expert_idx: int, + local_expert_idx: int, + weights: Dict[str, torch.Tensor], + ): + w1_scale = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_scale_suffix}" + row_slice_func = self.row_slicer._slice_weight_scale + col_slice_func = self.col_slicer._slice_weight_scale + if w1_scale in weights: + self.quant_method.load_weight_scale(row_slice_func(weights[w1_scale]), self.w1_list[local_expert_idx]) + if w3_scale in weights: + self.quant_method.load_weight_scale(row_slice_func(weights[w3_scale]), self.w3_list[local_expert_idx]) + if w2_scale in weights: + self.quant_method.load_weight_scale(col_slice_func(weights[w2_scale]), self.w2_list[local_expert_idx]) + + def _load_expert_zero_point( + self, + expert_idx: int, + local_expert_idx: int, + weights: Dict[str, torch.Tensor], + ): + w1_zero_point = ( + f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_zero_point_suffix}" + ) + w2_zero_point = ( + f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_zero_point_suffix}" + ) + w3_zero_point = ( + f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_zero_point_suffix}" + ) + row_slice_func = self.row_slicer._slice_weight_zero_point + col_slice_func = self.col_slicer._slice_weight_zero_point + if w1_zero_point in weights: + self.quant_method.load_weight_zero_point( + row_slice_func(weights[w1_zero_point]), self.w1_list[local_expert_idx] + ) + if w3_zero_point in weights: + 
self.quant_method.load_weight_zero_point( + row_slice_func(weights[w3_zero_point]), self.w3_list[local_expert_idx] + ) + if w2_zero_point in weights: + self.quant_method.load_weight_zero_point( + col_slice_func(weights[w2_zero_point]), self.w2_list[local_expert_idx] + ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py similarity index 73% rename from lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index df72cc6208..6ed0cef0b4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -3,9 +3,10 @@ import threading from typing import Optional, Tuple, List, Dict, Any -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -30,41 +31,43 @@ ] -class GPTOSSFusedMoeWeightTP(FusedMoeWeightTP): +class GPTOSSFusedMoeWeightTP(FusedMoeWeight): def __init__( self, - gate_up_proj_name: str, # diff with FusedMoeWeightTP + gate_up_proj_name: str, down_proj_name: str, e_score_correction_bias_name: str, weight_prefix: str, n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, + hidden_size: int, + moe_intermediate_size: int, data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - world_size: int = 1, # diff with FusedMoeWeightTP - quant_cfg: Quantcfg = None, + quant_method: QuantizationMethod = None, + num_fused_shared_experts: int = 0, + layer_num: int = 0, + network_config: Dict[str, Any] = None, ) -> None: + network_config["norm_topk_prob"] = None super().__init__( - gate_up_proj_name, - down_proj_name, - gate_up_proj_name, - e_score_correction_bias_name, - weight_prefix, - n_routed_experts, - num_fused_shared_experts, - split_inter_size, - data_type, - network_config, - layer_num, - quant_cfg, + gate_proj_name=gate_up_proj_name, + down_proj_name=down_proj_name, + up_proj_name=gate_up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + data_type=data_type, + quant_method=quant_method, + num_fused_shared_experts=num_fused_shared_experts, + layer_num=layer_num, + network_config=network_config, ) + self.hidden_size = network_config["hidden_size"] self.alpha = 1.702 self.limit = 7.0 - self.tp_world_size_ = world_size self.w1_bias = None self.w2_bias = None @@ -77,6 +80,12 @@ def __init__( self._gate_up_scales_name = f"{weight_prefix}.{gate_up_proj_name}_scales" return + def _create_weight(self): + """ + 因为加载方式比较特殊,不在这里创建weight。 + """ + pass + def _fuse_weight_scale(self): assert False, "Not implemented for GPT-OSS." 
@@ -116,22 +125,38 @@ def load_hf_weights(self, weights): w2_bias = weights[self._down_bias_name] self.w2_bias = self._cuda(w2_bias) - def router(self, router_logits, top_k): + def verify_load(self): + assert self.w1 is not None and self.w2 is not None + return True + + def _router(self, router_logits, top_k): router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_top_value, router_indices - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - topk_weights, topk_ids = self.router(router_logits, top_k) + def experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ): + + topk_weights, topk_ids = self._router(router_logits, top_k) w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None + use_fp8_w8a8 = False # TODO: disable fp8 for GPT-OSS for now - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_experts output_tensor = fused_experts( - hidden_states=input_tensor.to(torch.bfloat16), + hidden_states=input_tensor.to(w1.dtype), w1=w1, w2=w2, topk_weights=topk_weights, @@ -201,3 +226,6 @@ def _convert_moe_packed_tensors( out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) del blocks, scales, lut return out.transpose(1, 2).contiguous() + + def _cuda(self, cpu_tensor): + return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py new file mode 100644 index 0000000000..67bb90e4ef --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py @@ -0,0 +1,14 @@ +from lightllm.common.quantization.quantize_method import QuantizationMethod +from .triton_impl import FuseMoeTriton +from .marlin_impl import FuseMoeMarlin +from .deepgemm_impl import FuseMoeDeepGEMM + + +def select_fuse_moe_impl(quant_method: QuantizationMethod, enable_ep_moe: bool): + if enable_ep_moe: + return FuseMoeDeepGEMM + + if quant_method.method_name == "awq_marlin": + return FuseMoeMarlin + else: + return FuseMoeTriton diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py new file mode 100644 index 0000000000..00587ac185 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -0,0 +1,66 @@ +import torch +from abc import abstractmethod +from lightllm.common.quantization.quantize_method import ( + WeightPack, + QuantizationMethod, +) +from typing import Optional +from lightllm.utils.dist_utils import ( + get_global_rank, + get_global_world_size, +) + + +class FuseMoeBaseImpl: + def __init__( + self, + n_routed_experts: int, + num_fused_shared_experts: int, + routed_scaling_factor: float, + quant_method: QuantizationMethod, + redundancy_expert_num: int, + redundancy_expert_ids_tensor: torch.Tensor, + routed_expert_counter_tensor: torch.Tensor, + auto_update_redundancy_expert: bool, + ): + self.n_routed_experts = 
n_routed_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = routed_scaling_factor + self.quant_method = quant_method + self.global_rank_ = get_global_rank() + self.global_world_size_ = get_global_world_size() + self.ep_n_routed_experts = self.n_routed_experts // self.global_world_size_ + self.total_expert_num_contain_redundancy = ( + self.n_routed_experts + redundancy_expert_num * self.global_world_size_ + ) + + # redundancy expert related + self.redundancy_expert_num = redundancy_expert_num + self.redundancy_expert_ids_tensor = redundancy_expert_ids_tensor + self.routed_expert_counter_tensor = routed_expert_counter_tensor + self.auto_update_redundancy_expert = auto_update_redundancy_expert + + # workspace for kernel optimization + self.workspace = self.create_workspace() + + @abstractmethod + def create_workspace(self): + pass + + @abstractmethod + def __call__( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + correction_bias: Optional[torch.Tensor], + scoring_func: str, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ) -> torch.Tensor: + pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py new file mode 100644 index 0000000000..f00d572d9d --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -0,0 +1,336 @@ +import torch +from typing import Optional, Tuple, Any +from .triton_impl import FuseMoeTriton +from lightllm.distributed import dist_group_manager +from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( + fused_experts_impl, + masked_group_gemm, + _deepgemm_grouped_fp8_nt_contiguous, +) +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( + per_token_group_quant_fp8, + tma_align_input_scale, +) +from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair + + +class FuseMoeDeepGEMM(FuseMoeTriton): + def _select_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + correction_bias: Optional[torch.Tensor], + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + scoring_func: str, + ): + """Select experts and return topk weights and ids.""" + from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.redundancy_expert_num > 0: + redundancy_topk_ids_repair( + topk_ids=topk_ids, + 
redundancy_expert_ids=self.redundancy_expert_ids_tensor, + ep_expert_num=self.ep_n_routed_experts, + global_rank=self.global_rank_, + expert_counter=self.routed_expert_counter_tensor, + enable_counter=self.auto_update_redundancy_expert, + ) + return topk_weights, topk_ids + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = None, + ): + + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + use_fp8_w8a8 = self.quant_method.method_name != "none" + output = fused_experts_impl( + hidden_states=input_tensor, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + topk_idx=topk_ids.to(torch.long), + num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy + buffer=dist_group_manager.ep_buffer, + is_prefill=is_prefill, + use_fp8_w8a8=use_fp8_w8a8, + use_fp8_all2all=use_fp8_w8a8, + use_int8_w8a16=False, # default to False + w1_scale=w13_scale, + w2_scale=w2_scale, + previous_event=None, # for overlap + ) + return output + + def low_latency_dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + e_score_correction_bias: torch.Tensor, + use_grouped_topk: bool, + num_experts_per_tok: int, + norm_topk_prob: bool, + topk_group: int, + n_group: int, + scoring_func: str, + ): + topk_weights, topk_idx = self._select_experts( + input_tensor=hidden_states, + router_logits=router_logits, + correction_bias=e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=num_experts_per_tok, + renormalize=norm_topk_prob, + topk_group=topk_group, + num_expert_group=n_group, + scoring_func=scoring_func, + ) + + topk_idx = topk_idx.to(torch.long) + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + use_fp8_w8a8 = self.quant_method.method_name != "none" + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + self.total_expert_num_contain_redundancy, + use_fp8=use_fp8_w8a8, + async_finish=False, + return_recv_hook=True, + ) + return recv_x, masked_m, topk_idx, topk_weights, handle, hook + + def select_experts_and_quant_input( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + e_score_correction_bias: torch.Tensor, + w13: WeightPack, + use_grouped_topk: bool, + num_experts_per_tok: int, + norm_topk_prob: bool, + topk_group: int, + n_group: int, + scoring_func: str, + ): + topk_weights, topk_idx = self._select_experts( + input_tensor=hidden_states, + router_logits=router_logits, + correction_bias=e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=num_experts_per_tok, + renormalize=norm_topk_prob, + topk_group=topk_group, + num_expert_group=n_group, + scoring_func=scoring_func, + ) + w13_weight, w13_scale = w13.weight, w13.weight_scale + block_size_k = 0 + if w13_weight.ndim == 3: + block_size_k = w13_weight.shape[2] // w13_scale.shape[2] + assert block_size_k == 128, "block_size_k must be 128" + qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype) + return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) + + def dispatch( + self, + qinput_tensor: Tuple[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_event: Optional[Any] = None, + 
+    ):
+        buffer = dist_group_manager.ep_buffer
+        # get_dispatch_layout
+        (
+            num_tokens_per_rank,
+            num_tokens_per_rdma_rank,
+            num_tokens_per_expert,
+            is_token_in_rank,
+            previous_event,
+        ) = buffer.get_dispatch_layout(
+            topk_idx,
+            self.total_expert_num_contain_redundancy,
+            previous_event=overlap_event,
+            async_finish=True,
+            allocate_on_comm_stream=True,
+        )
+        recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
+            qinput_tensor,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            is_token_in_rank=is_token_in_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+            previous_event=previous_event,
+            async_finish=True,
+            allocate_on_comm_stream=True,
+            expert_alignment=128,
+        )
+
+        def hook():
+            event.current_stream_wait()
+
+        return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook
+
+    def masked_group_gemm(
+        self,
+        recv_x: Tuple[torch.Tensor],
+        w13: WeightPack,
+        w2: WeightPack,
+        masked_m: torch.Tensor,
+        dtype: torch.dtype,
+        expected_m: int,
+    ):
+        w13_weight, w13_scale = w13.weight, w13.weight_scale
+        w2_weight, w2_scale = w2.weight, w2.weight_scale
+        return masked_group_gemm(
+            recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m
+        )
+
+    def prefilled_group_gemm(
+        self,
+        num_recv_tokens_per_expert_list,
+        recv_x: Tuple[torch.Tensor],
+        recv_topk_idx: torch.Tensor,
+        recv_topk_weights: torch.Tensor,
+        w13: WeightPack,
+        w2: WeightPack,
+        hidden_dtype=torch.bfloat16,
+    ):
+        device = recv_x[0].device
+        w13_weight, w13_scale = w13.weight, w13.weight_scale
+        w2_weight, w2_scale = w2.weight, w2.weight_scale
+        _, K = recv_x[0].shape
+        _, N, _ = w13_weight.shape
+        block_size = self.quant_method.block_size
+        # scatter
+        all_tokens = sum(num_recv_tokens_per_expert_list)  # total padded token count across local experts
+        # gather_out shape: [received_num_tokens, hidden]
+        gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype)
+        if all_tokens > 0:
+            input_tensor = [
+                torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype),
+                torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32),
+            ]
+            # m_indices records which local expert each scattered row belongs to, e.g.
+            # [0, 0, 0, 0, ..., 1, 1, 1, 1, ..., cur_expert_num - 1, ...]: expert i appears
+            # num_recv_tokens_per_expert_list[i] times, so rows for the same expert are
+            # contiguous for the grouped GEMMs below.
+            m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32)
+            # output_index shape: [received_num_tokens, topk_num]
+            # output_index maps each (token, topk) slot to its row index in input_tensor
+            output_index = torch.empty_like(recv_topk_idx)
+
+            num_recv_tokens_per_expert = torch.tensor(
+                num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu"
+            ).cuda(non_blocking=True)
+
+            expert_start_loc = torch.empty_like(num_recv_tokens_per_expert)
+
+            ep_scatter(
+                recv_x[0],
+                recv_x[1],
+                recv_topk_idx,
+                num_recv_tokens_per_expert,
+                expert_start_loc,
+                input_tensor[0],
+                input_tensor[1],
+                m_indices,
+                output_index,
+            )
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
+            # groupgemm (contiguous layout)
+            gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)
+
+            _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices)
+
+            # silu_and_mul_fwd + quant
+            # TODO: fuse into a single kernel
+            silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype)
+
+            silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
+            qsilu_out, qsilu_out_scale = per_token_group_quant_fp8(
+                silu_out, block_size, dtype=w13_weight.dtype, column_major_scales=True, scale_tma_aligned=True
+            )
+
+            # groupgemm (contiguous layout)
+            gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)
+
+            _deepgemm_grouped_fp8_nt_contiguous(
+                (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices
+            )
+            # gather and local reduce
+            ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
+        else:
+            ######################################## warning ##################################################
+            # This branch keeps autotuning consistent: every rank should run the same triton kernels.
+            # In some special cases a rank receives 0 tokens, so run the kernels on a dummy token during warmup.
+ if Autotuner.is_autotune_warmup(): + _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype) + _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype) + silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) + _gemm_out_a, _silu_out = None, None + + return gather_out + + def low_latency_combine( + self, + gemm_out_b: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: Any, + ): + combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True + ) + return combined_x, hook + + def combine( + self, + gemm_out_b: torch.Tensor, + handle: Any, + overlap_event: Optional[Any] = None, + ): + # normal combine + combined_x, _, event = dist_group_manager.ep_buffer.combine( + gemm_out_b, + handle, + topk_weights=None, + async_finish=True, + previous_event=overlap_event, + allocate_on_comm_stream=True, + ) + + def hook(): + event.current_stream_wait() + + return combined_x, hook diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py new file mode 100644 index 0000000000..6391a10800 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -0,0 +1,61 @@ +import torch +from .triton_impl import FuseMoeTriton +from lightllm.common.quantization.quantize_method import ( + WeightPack, +) +from lightllm.common.quantization.awq import ( + AWQMARLINW4A16QuantizationMethod, +) +from typing import Optional + + +class FuseMoeMarlin(FuseMoeTriton): + def create_workspace(self): + from lightllm.utils.vllm_utils import HAS_VLLM + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + ) + + return marlin_make_workspace_new(torch.device("cuda"), 4) + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = None, + ): + + w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point + w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + self.quant_method: AWQMARLINW4A16QuantizationMethod = self.quant_method + + fused_marlin_moe( + input_tensor, + w1_weight, + w2_weight, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + return input_tensor diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py new file mode 100644 index 0000000000..8bcdb4bf90 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -0,0 +1,148 @@ +import torch +from typing import Optional +from lightllm.common.quantization.no_quant import WeightPack +from lightllm.common.quantization.quantize_method import QuantizationMethod +from .base_impl 
import FuseMoeBaseImpl + + +class FuseMoeTriton(FuseMoeBaseImpl): + def __init__( + self, + n_routed_experts: int, + num_fused_shared_experts: int, + routed_scaling_factor: float, + quant_method: QuantizationMethod, + redundancy_expert_num: int, + redundancy_expert_ids_tensor: torch.Tensor, + routed_expert_counter_tensor: torch.Tensor, + auto_update_redundancy_expert: bool, + ): + super().__init__( + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + quant_method=quant_method, + redundancy_expert_num=redundancy_expert_num, + redundancy_expert_ids_tensor=redundancy_expert_ids_tensor, + routed_expert_counter_tensor=routed_expert_counter_tensor, + auto_update_redundancy_expert=auto_update_redundancy_expert, + ) + + def create_workspace(self): + return None + + def _select_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + correction_bias: Optional[torch.Tensor], + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + scoring_func: str, + ): + """Select experts and return topk weights and ids.""" + from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts, + end=self.n_routed_experts + self.num_fused_shared_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + return topk_weights, topk_ids + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ): + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + use_fp8_w8a8 = w13_weight.dtype == torch.float8_e4m3fn + + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_experts + + fused_experts( + hidden_states=input_tensor, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + return input_tensor + + def __call__( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + correction_bias: Optional[torch.Tensor], + scoring_func: str, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ): + topk_weights, topk_ids = self._select_experts( + input_tensor=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + 
topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + ) + output = self._fused_experts( + input_tensor=input_tensor, + w13=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + router_logits=router_logits, + is_prefill=is_prefill, + ) + return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py deleted file mode 100644 index 7dc5b5fdcc..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ /dev/null @@ -1,540 +0,0 @@ -import os -import torch -import threading -from typing import Optional, Tuple, List, Dict, Any -from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id -from .base_weight import BaseWeight -from lightllm.common.fused_moe.grouped_fused_moe_ep import ( - fused_experts_impl, - masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, -) -from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd -from lightllm.distributed import dist_group_manager -from lightllm.common.fused_moe.topk_select import select_experts -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank -from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( - per_token_group_quant_fp8, - tma_align_input_scale, -) -from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair -from lightllm.utils.log_utils import init_logger -from lightllm.common.triton_utils.autotuner import Autotuner - - -logger = init_logger(__name__) - - -class FusedMoeWeightEP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg=None, - ) -> None: - super().__init__() - - self.layer_num = layer_num - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.quant_method.is_moe = True - block_size = 1 - if hasattr(self.quant_method, "block_size"): - block_size = self.quant_method.block_size - self.block_size = block_size - - self.weight_prefix = weight_prefix - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - self.e_score_correction_bias_name = e_score_correction_bias_name - self.n_routed_experts = n_routed_experts - self.data_type_ = data_type - - global_world_size = get_global_world_size() - self.global_rank_ = get_global_rank() - self.redundancy_expert_num = get_redundancy_expert_num() - self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num) - logger.info( - f"global_rank {self.global_rank_} layerindex {layer_num} redundancy_expertids: {self.redundancy_expert_ids}" - ) - self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") - self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, 
device="cuda") - self.total_expert_num_contain_redundancy = ( - self.n_routed_experts + self.redundancy_expert_num * global_world_size - ) - assert self.n_routed_experts % global_world_size == 0 - self.ep_n_routed_experts = self.n_routed_experts // global_world_size - ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num - self.experts_up_projs = [None] * ep_load_expert_num - self.experts_gate_projs = [None] * ep_load_expert_num - self.experts_up_proj_scales = [None] * ep_load_expert_num - self.experts_gate_proj_scales = [None] * ep_load_expert_num - self.e_score_correction_bias = None - self.w2_list = [None] * ep_load_expert_num - self.w2_scale_list = [None] * ep_load_expert_num - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None] # weight, weight_scale - self.w2 = [None, None] # weight, weight_scale - self.use_fp8_w8a8 = self.quant_method is not None - network_config["n_group"] = network_config.get("n_group", 0) - self.num_experts_per_tok = network_config["num_experts_per_tok"] - self.use_grouped_topk = network_config["n_group"] > 0 - self.norm_topk_prob = network_config["norm_topk_prob"] - self.n_group = network_config["n_group"] - network_config["topk_group"] = network_config.get("topk_group", 0) - self.topk_group = network_config["topk_group"] - network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 1.0) - self.routed_scaling_factor = network_config["routed_scaling_factor"] - - self.lock = threading.Lock() - # init buffer - - # auto update redundancy expert vars - self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert - - def experts( - self, - input_tensor, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - is_prefill, - ): - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - - if self.redundancy_expert_num > 0: - redundancy_topk_ids_repair( - topk_ids=topk_ids, - redundancy_expert_ids=self.redundancy_expert_ids_tensor, - ep_expert_num=self.ep_n_routed_experts, - global_rank=self.global_rank_, - expert_counter=self.routed_expert_counter_tensor, - enable_counter=self.auto_update_redundancy_expert, - ) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - return fused_experts_impl( - hidden_states=input_tensor, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_idx=topk_ids.to(torch.long), - num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=dist_group_manager.ep_buffer, - is_prefill=is_prefill, - use_fp8_w8a8=self.use_fp8_w8a8, - use_fp8_all2all=self.use_fp8_w8a8, - use_int8_w8a16=False, # default to False - w1_scale=w1_scale, - w2_scale=w2_scale, - previous_event=None, # for overlap - ) - - def low_latency_dispatch( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ): - - topk_weights, topk_idx = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=self.use_grouped_topk, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - topk_group=self.topk_group, - num_expert_group=self.n_group, - 
scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - - if self.redundancy_expert_num > 0: - redundancy_topk_ids_repair( - topk_ids=topk_idx, - redundancy_expert_ids=self.redundancy_expert_ids_tensor, - ep_expert_num=self.ep_n_routed_experts, - global_rank=self.global_rank_, - expert_counter=self.routed_expert_counter_tensor, - enable_counter=self.auto_update_redundancy_expert, - ) - - topk_idx = topk_idx.to(torch.long) - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() - recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - self.total_expert_num_contain_redundancy, - use_fp8=self.use_fp8_w8a8, - async_finish=False, - return_recv_hook=True, - ) - return recv_x, masked_m, topk_idx, topk_weights, handle, hook - - def select_experts_and_quant_input( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ): - topk_weights, topk_idx = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=self.use_grouped_topk, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - topk_group=self.topk_group, - num_expert_group=self.n_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.redundancy_expert_num > 0: - redundancy_topk_ids_repair( - topk_ids=topk_idx, - redundancy_expert_ids=self.redundancy_expert_ids_tensor, - ep_expert_num=self.ep_n_routed_experts, - global_rank=self.global_rank_, - expert_counter=self.routed_expert_counter_tensor, - enable_counter=self.auto_update_redundancy_expert, - ) - M, K = hidden_states.shape - w1, w1_scale = self.w1 - block_size_k = 0 - if w1.ndim == 3: - block_size_k = w1.shape[2] // w1_scale.shape[2] - assert block_size_k == 128, "block_size_k must be 128" - qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype) - return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) - - def dispatch( - self, - qinput_tensor: Tuple[torch.Tensor], - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - overlap_event: Optional[Any] = None, - ): - buffer = dist_group_manager.ep_buffer - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, - self.total_expert_num_contain_redundancy, - previous_event=overlap_event, - async_finish=True, - allocate_on_comm_stream=True, - ) - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( - qinput_tensor, - topk_idx=topk_idx, - topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=True, - allocate_on_comm_stream=True, - expert_alignment=128, - ) - - def hook(): - event.current_stream_wait() - - return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook - - def masked_group_gemm( - self, recv_x: Tuple[torch.Tensor], masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int - ): - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - return masked_group_gemm(recv_x, masked_m, dtype, w1, w1_scale, w2, w2_scale, expected_m=expected_m) - - def 
prefilled_group_gemm( - self, - num_recv_tokens_per_expert_list, - recv_x: Tuple[torch.Tensor], - recv_topk_idx: torch.Tensor, - recv_topk_weights: torch.Tensor, - hidden_dtype=torch.bfloat16, - ): - device = recv_x[0].device - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - _, K = recv_x[0].shape - _, N, _ = w1.shape - # scatter - all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. - # gather_out shape [recive_num_tokens, hidden] - gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype) - if all_tokens > 0: - input_tensor = [ - torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype), - torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32), - ] - # when m_indices is filled ok. - # m_indices show token use which expert, example, [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..] - # the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1] - # ... - m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32) - # output_index shape [recive_num_tokens, topk_num] - # output_index use to show the token index in input_tensor - output_index = torch.empty_like(recv_topk_idx) - - num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" - ).cuda(non_blocking=True) - - expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) - - ep_scatter( - recv_x[0], - recv_x[1], - recv_topk_idx, - num_recv_tokens_per_expert, - expert_start_loc, - input_tensor[0], - input_tensor[1], - m_indices, - output_index, - ) - input_tensor[1] = tma_align_input_scale(input_tensor[1]) - # groupgemm (contiguous layout) - gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) - - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) - - # silu_and_mul_fwd + qaunt - # TODO fused kernel - silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype) - - silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) - qsilu_out, qsilu_out_scale = per_token_group_quant_fp8( - silu_out, self.block_size, dtype=w1.dtype, column_major_scales=True, scale_tma_aligned=True - ) - - # groupgemm (contiguous layout) - gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) - - _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) - # gather and local reduce - ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) - else: - ######################################## warning ################################################## - # here is used to match autotune feature, make moe model run same triton kernel in different rank. - # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel. 
- if Autotuner.is_autotune_warmup(): - _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype) - _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype) - silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) - _gemm_out_a, _silu_out = None, None - - return gather_out - - def low_latency_combine( - self, - gemm_out_b: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - handle: Any, - ): - combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( - gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True - ) - return combined_x, hook - - def combine( - self, - gemm_out_b: torch.Tensor, - handle: Any, - overlap_event: Optional[Any] = None, - ): - # normal combine - combined_x, _, event = dist_group_manager.ep_buffer.combine( - gemm_out_b, - handle, - topk_weights=None, - async_finish=True, - previous_event=overlap_event, - allocate_on_comm_stream=True, - ) - - def hook(): - event.current_stream_wait() - - return combined_x, hook - - def _fuse(self): - if self.quantized_weight: - self._fuse_weight_scale() - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape - up_out_dim, up_in_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_projs[0].dtype - total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num - - w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") - - for i_experts in range(self.ep_n_routed_experts + self.redundancy_expert_num): - w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] - w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - if not self.quantized_weight and self.quant_method is not None: - qw1, qw1_scale, qw1_zero_point = self.quant_method.quantize(w1) - qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2) - self.w1[0] = qw1 - self.w1[1] = qw1_scale - self.w2[0] = qw2 - self.w2[1] = qw2_scale - else: - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape - up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_proj_scales[0].dtype - total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num - - w1_scale = torch.empty( - (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" - ) - - for i_experts in range(self.ep_n_routed_experts + self.redundancy_expert_num): - w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] - - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = 
torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale) - self.w2[1] = self._cuda(w2_scale) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def load_hf_weights(self, weights): - n_expert_ep = self.ep_n_routed_experts - # tp to ep here - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - - for i_experts_ep in range(n_expert_ep): - i_experts = i_experts_ep + n_expert_ep * self.global_rank_ - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[i_experts_ep] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[i_experts_ep] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[i_experts_ep] = weights[w2_weight] - - # Load weight parameters for redundant experts - for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): - i_experts = redundant_expert_id - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[n_expert_ep + i] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[n_expert_ep + i] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[n_expert_ep + i] = weights[w2_weight] - - if self.quantized_weight: - self._load_weight_scale(weights) - self._fuse() - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - n_expert_ep = self.ep_n_routed_experts - for i_experts_ep in range(n_expert_ep): - i_experts = i_experts_ep + n_expert_ep * self.global_rank_ - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts_ep] = weights[w1_scale] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts_ep] = weights[w3_scale] - - if w2_scale in weights: - self.w2_scale_list[i_experts_ep] = weights[w2_scale] - - # Load scale parameters for redundant experts - for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): - i_experts = redundant_expert_id - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - if w1_scale in weights: - self.experts_gate_proj_scales[n_expert_ep + i] = weights[w1_scale] - if w3_scale in weights: - self.experts_up_proj_scales[n_expert_ep + i] = weights[w3_scale] - if w2_scale in weights: - self.w2_scale_list[n_expert_ep + i] = weights[w2_scale] - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.contiguous().cuda(device_id) - return 
cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py deleted file mode 100644 index 9295fa96ae..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ /dev/null @@ -1,669 +0,0 @@ -import os -import torch -import threading -from typing import Optional, Tuple, List, Dict, Any, Union -from .base_weight import BaseWeight -from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id -from lightllm.common.quantization import Quantcfg - - -def create_tp_moe_wegiht_obj( - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, -) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: - quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - if quant_method is not None and quant_method.method_name == "awq_marlin": - return FusedAWQMARLINMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - else: - return FusedMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - - -class FusedMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.quant_method.is_moe = True - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None] # weight, weight_scale - self.w2 = [None, None] # weight, weight_scale - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - use_fp8_w8a8 = self.quant_method is not None - - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts - - fused_experts( - hidden_states=input_tensor, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) - return - - def _fuse(self): - if self.quantized_weight: - self._fuse_weight_scale() - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape - up_out_dim, up_in_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_projs[0].dtype - total_expert_num = self.n_routed_experts - - w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") - - for i_experts in range(self.n_routed_experts): - w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] - w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - if not self.quantized_weight and self.quant_method is not None: - qw1, qw1_scale, qw1_zero_point = 
self.quant_method.quantize(w1) - qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2) - self.w1[0] = qw1 - self.w1[1] = qw1_scale - self.w2[0] = qw2 - self.w2[1] = qw2_scale - else: - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape - up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_proj_scales[0].dtype - total_expert_num = self.n_routed_experts - - w1_scale = torch.empty( - (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale) - self.w2[1] = self._cuda(w2_scale) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def load_hf_weights(self, weights): - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if self.quant_method is not None: - self._load_weight_scale(weights) - self._fuse() - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - block_size = 1 - if hasattr(self.quant_method, "block_size"): - block_size = self.quant_method.block_size - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - self.split_inter_size - // 
block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] = weights[w2_scale][ - :, - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - ] - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.contiguous().cuda(device_id) - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None - - -class FusedAWQMARLINMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.weight_zero_point_suffix = self.quant_method.weight_zero_point_suffix - self.quant_method.is_moe = True - hf_quantization_config = network_config.get("quantization_config", None) - self.num_bits = hf_quantization_config.get("bits", 4) - self.group_size = hf_quantization_config.get("group_size", 128) - self.pack_factor = 32 // self.num_bits - self.has_processed_weight = False - assert self.quant_method.method_name == "awq_marlin" - - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_up_proj_zero_points = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_zero_points = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.w2_zero_point_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None, None] # weight, weight_scale, zero_point - self.w2 = [None, None, None] # weight, weight_scale, zero_point - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale, w1_zero_point = self.w1 - w2, w2_scale, w2_zero_point = self.w2 - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe - - fused_marlin_moe( - input_tensor, - w1, - w2, - None, - None, - w1_scale, - w2_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=self.quant_method.vllm_quant_type.id, - apply_router_weight_on_input=False, - global_num_experts=-1, - expert_map=None, - w1_zeros=w1_zero_point, - w2_zeros=w2_zero_point, - workspace=self.workspace, - inplace=True, - ) - - return - - def _fuse(self): - self._fuse_weight() - self._fuse_weight_scale() - self._fuse_weight_zero_point() - - def _fuse_weight(self): - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_projs[0].shape - up_in_dim, up_out_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - - w1 = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1[i_experts, :, 0:gate_out_dim] 
= self.experts_gate_projs[i_experts] - w1[i_experts, :, gate_out_dim:] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_scales[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_scales[0].shape - dtype = self.experts_gate_proj_scales[0].dtype - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_scale = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=dtype, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, :, gate_out_dim:] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale).to(self.data_type_) - self.w2[1] = self._cuda(w2_scale).to(self.data_type_) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def _fuse_weight_zero_point(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_zero_points") - and None not in self.experts_up_proj_zero_points - and None not in self.experts_gate_proj_zero_points - and None not in self.w2_zero_point_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_zero_points[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_zero_points[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_zero_point = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_zero_point[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_zero_points[i_experts] - w1_zero_point[i_experts, :, gate_out_dim:] = self.experts_up_proj_zero_points[i_experts] - inter_shape, hidden_size = self.w2_zero_point_list[0].shape[0], self.w2_zero_point_list[0].shape[1] - w2_zero_point = torch._utils._flatten_dense_tensors(self.w2_zero_point_list).view( - len(self.w2_zero_point_list), inter_shape, hidden_size - ) - self.w1[2] = self._cuda(w1_zero_point) - self.w2[2] = self._cuda(w2_zero_point) - delattr(self, "w2_zero_point_list") - delattr(self, "experts_up_proj_zero_points") - delattr(self, "experts_gate_proj_zero_points") - - def load_hf_weights(self, weights): - self._load_weight(weights) - self._load_weight_scale(weights) - self._load_weight_zero_point(weights) - self._fuse() - self._process_weight_after_loading() - - def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: - # awq quantization weight shape: in x out - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - 
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.qweight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.qweight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.qweight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - split_inter_size = self.split_inter_size * self.pack_factor - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] = weights[w2_scale][ - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_zero_point_suffix}" - w2_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_zero_point_suffix}" - w3_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_zero_point_suffix}" - if w1_zero_point in weights: - self.experts_gate_proj_zero_points[i_experts] = weights[w1_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w3_zero_point in weights: - self.experts_up_proj_zero_points[i_experts] = weights[w3_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w2_zero_point in weights: - self.w2_zero_point_list[i_experts] = weights[w2_zero_point][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _process_weight_after_loading(self): - with self.lock: - if None in self.w1 or None in self.w2 or self.has_processed_weight: - return - self.has_processed_weight = True - from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - - assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" - - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, - moe_awq_to_marlin_zero_points, - marlin_make_workspace_new, - ) - - num_experts = self.n_routed_experts - device = self.w1[0].device - - self.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w2_g_idx_sort_indices = 
torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w1[0] = vllm_ops.awq_marlin_moe_repack( - self.w1[0], - self.w13_g_idx_sort_indices, - size_k=self.w1[0].shape[1], - size_n=self.w1[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[0] = vllm_ops.awq_marlin_moe_repack( - self.w2[0], - self.w2_g_idx_sort_indices, - size_k=self.w2[0].shape[1], - size_n=self.w2[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - # Why does this take the intermediate size for size_k? - self.w1[1] = marlin_moe_permute_scales( - s=self.w1[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w1[1].shape[2], - group_size=self.group_size, - ) - - self.w2[1] = marlin_moe_permute_scales( - s=self.w2[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w2[1].shape[2], - group_size=self.group_size, - ) - - self.w1[2] = moe_awq_to_marlin_zero_points( - self.w1[2], - size_k=self.w1[2].shape[1], - size_n=self.w1[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[2] = moe_awq_to_marlin_zero_points( - self.w2[2], - size_k=self.w2[2].shape[1], - size_n=self.w2[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.workspace = marlin_make_workspace_new(device, 4) - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.cuda(device_id) - return cpu_tensor.cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 63605b1774..e9ae4f30ab 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,10 +1,5 @@ from .mm_weight import ( - MMWeightPack, MMWeightTpl, ) -from .mm_factory import ( - MMWeight, - ROWMMWeight, - ROWBMMWeight, - COLMMWeight, -) +from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight +from .colmm_weight import COLMMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 281f30f022..1a02e00d0f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -1,19 +1,20 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size +from .mm_slicer import get_col_slice_mixin -class StandardCOLMMWeight(MMWeightTpl): +class COLMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -21,82 +22,19 @@ def __init__( tp_rank: int = None, tp_world_size: 
int = None, ) -> None: + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + in_dim = self._get_tp_dim(in_dim) super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, data_type=data_type, bias_names=bias_names, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, ) - self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQCOLMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - # 注意这里不是错误,因为awq的weight是按inxout存的 - self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQMARLINCOLMMWeight(AWQCOLMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) - - -COLMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, - "awq": AWQCOLMMWeight, - "awq_marlin": AWQMARLINCOLMMWeight, -} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py deleted file mode 100644 index 464de84413..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from lightllm.common.quantization import Quantcfg -from lightllm.common.quantization.quantize_method import QuantizationMethod -from typing import Type, Union, Dict -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeightTpl, - BMMWeightTpl, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ( - StandardROWMMWeight, - UnquantizedROWBMMWeight, - ROWMM_WEIGHT_CLS_MAP, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( - StandardCOLMMWeight, - COLMM_WEIGHT_CLS_MAP, -) - - -class MMWeight: - def __new__(cls, **kwargs): - """ - weight_names, - 
data_type, - bias_names, - quant_cfg, - layer_num, - name, - tp_rank, - tp_world_size, - ... - 该类主要是通过重载 __new__ 为对应的mm权重绑定量化方法,其他参数都是透传。 - """ - - quant_cfg = kwargs.pop("quant_cfg", None) - layer_num_ = kwargs.pop("layer_num", None) - name = kwargs.pop("name", None) - quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) - # quantized_weight 本身是用来标识权重本身在文件中是否是以量化后的形式存储, - # 现在不再使用该参数,是否量化由后续的加载过程自动识别。 - kwargs["quant_method"] = quant_method - mmcls = cls._get_mmcls(quant_method) - return mmcls(**kwargs) - - @classmethod - def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: - if quant_cfg is None: - return None, False - quant_method: QuantizationMethod = quant_cfg.get_quant_method(layer_num_, name) - if quant_method is None: - return None, False - quant_method.hf_quantization_config = quant_cfg.hf_quantization_config - quantized_weight = quant_cfg.quantized_weight - return quant_method, quantized_weight - - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod) -> Type[Union[MMWeightTpl, BMMWeightTpl]]: - raise NotImplementedError("Subclasses must implement _get_mmcls method") - - -class ROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardROWMMWeight - - return ROWMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardROWMMWeight, - ) - - -class ROWBMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return UnquantizedROWBMMWeight - else: - # TODO: Implement more quantization weight - raise NotImplementedError("ROWBMMWeight is not implemented") - - -class COLMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardCOLMMWeight - return COLMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardCOLMMWeight, - ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index e3ef5b0ea3..ddbf98a866 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -1,5 +1,5 @@ import torch -from typing import Optional +from typing import Optional, Tuple from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size @@ -7,9 +7,12 @@ class SliceMixinBase(ABC): """切片操作的Mixin基类""" - def __init__(self, tp_rank: int = None, tp_world_size: int = None): + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + # this param is used to slice the weight when tp_world_size_ is divisible by the kv_head_num + # for example, if tp_world_size_ is 8 and kv_head_num is 4, then repeat_times_ is 2 + self.repeat_times_ = repeat_times @abstractmethod def _slice_weight(self, weight: torch.Tensor): @@ -19,10 +22,16 @@ def _slice_weight(self, weight: torch.Tensor): def _slice_bias(self, bias): pass + def _get_slice_start_end(self, size: int) -> Tuple[int, int]: + tp_size = size * self.repeat_times_ // self.tp_world_size_ + start = tp_size * (self.tp_rank_ // self.repeat_times_) + end = start + tp_size + return start, end + 
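
The `_get_slice_start_end` helper added above drives all of the row/column slicers that follow. A minimal standalone sketch of that arithmetic (function name and numbers here are illustrative only, not part of the patch):

```python
# Standalone sketch: the sharding arithmetic behind _get_slice_start_end.
# With repeat_times > 1, groups of adjacent ranks map to the same slice,
# i.e. the shard is replicated instead of split further.
def slice_start_end(size: int, tp_rank: int, tp_world_size: int, repeat_times: int = 1):
    assert size * repeat_times % tp_world_size == 0
    tp_size = size * repeat_times // tp_world_size
    start = tp_size * (tp_rank // repeat_times)
    return start, start + tp_size

# Hypothetical numbers: 4 KV heads, head_dim 128, tp_world_size 8 -> repeat_times 2,
# so ranks 0/1 read head 0, ranks 2/3 read head 1, and so on.
kv_head_num, head_dim, tp_world_size = 4, 128, 8
repeat_times = tp_world_size // kv_head_num
for rank in range(tp_world_size):
    print(rank, slice_start_end(kv_head_num * head_dim, rank, tp_world_size, repeat_times))
```

With `repeat_times=1` this reduces to the usual contiguous tensor-parallel split; a larger `repeat_times` makes adjacent ranks share one shard, which is how KV heads appear to be replicated when the TP world size exceeds the KV head count.
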
 class SliceMixinTpl(SliceMixinBase):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
+        super().__init__(tp_rank, tp_world_size, repeat_times)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError("slice_weight must implement this method")
@@ -40,95 +49,117 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
 # 默认weight 的shape是 outxin,这也是目前最通用的约定。
 # 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
 class RowSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
+        super().__init__(tp_rank, tp_world_size, repeat_times)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
-        assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
-        tp_size = weight.shape[0] // self.tp_world_size_
-        return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+        assert (
+            weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0
+        ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}"
+        start, end = self._get_slice_start_end(weight.shape[0])
+        return weight[start:end, :]

     def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
-        tp_size = bias.shape[0] // self.tp_world_size_
-        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+        assert (
+            bias.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0
+        ), f"tp slice error {bias.shape[0] * self.repeat_times_} % {self.tp_world_size_}"
+        start, end = self._get_slice_start_end(bias.shape[0])
+        return bias[start:end]


 # 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。
 # 后续按需要,扩展per-tensor、per-channel的量化方式。
 class QuantizedRowSliceMixin(RowSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
+        super().__init__(tp_rank, tp_world_size, repeat_times)

     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
             weight_scale.shape[0] % self.tp_world_size_ == 0
         ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}"
-        tp_size = weight_scale.shape[0] // self.tp_world_size_
-        scale_start = tp_size * self.tp_rank_
-        scale_end = tp_size * (self.tp_rank_ + 1)
-        return weight_scale[scale_start:scale_end]
+        start, end = self._get_slice_start_end(weight_scale.shape[0])
+        return weight_scale[start:end]

     def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
         assert (
             weight_zero_point.shape[0] % self.tp_world_size_ == 0
         ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}"
-        tp_size = weight_zero_point.shape[0] // self.tp_world_size_
-        zero_point_start = tp_size * self.tp_rank_
-        zero_point_end = tp_size * (self.tp_rank_ + 1)
-        return weight_zero_point[zero_point_start:zero_point_end]
+        start, end = self._get_slice_start_end(weight_zero_point.shape[0])
+        return weight_zero_point[start:end]


 class ColSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int =
None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: - assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}" - tp_size = weight.shape[1] // self.tp_world_size_ - return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + assert ( + weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[1]) + return weight[:, start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias / self.tp_world_size_ + return bias / self.tp_world_size_ * self.repeat_times_ class QuantizedColSliceMixin(ColSliceMixin): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[1] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[1] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[:, scale_start:scale_end] + weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[1]) + return weight_scale[:, start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1]} % {self.tp_world_size_}" - tp_size = weight_zero_point.shape[1] // self.tp_world_size_ - zero_point_start = tp_size * self.tp_rank_ - zero_point_end = tp_size * (self.tp_rank_ + 1) - return weight_zero_point[:, zero_point_start:zero_point_end] + weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[1]) + return weight_zero_point[:, start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + assert ( + bias.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {bias.shape[0] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(bias.shape[0]) + return bias[start:end] class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin): - def 
__init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias / self.tp_world_size_ + return bias / self.tp_world_size_ * self.repeat_times_ + + +def get_row_slice_mixin( + quant_method_name: str, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1 +) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) + elif quant_method_name == "none": + return RowSliceMixin(tp_rank, tp_world_size, repeat_times) + else: + return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) + + +def get_col_slice_mixin( + quant_method_name: str, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1 +) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) + elif quant_method_name == "none": + return ColSliceMixin(tp_rank, tp_world_size, repeat_times) + else: + return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 7391454da0..3630bc2c00 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -5,9 +5,10 @@ from dataclasses import dataclass from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.no_quant import NoQuantization from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger from .mm_slicer import SliceMixinTpl @@ -15,53 +16,11 @@ logger = init_logger(__name__) -@dataclass -class MMWeightPack: - weight: Optional[torch.Tensor] = None - bias: Optional[torch.Tensor] = None - weight_scale: Optional[torch.Tensor] = None - weight_zero_point: Optional[torch.Tensor] = None - - has_bias: bool = False - has_weight_scale: bool = False - has_weight_zero_point: bool = False - - def is_ready(self) -> bool: - return ( - self.weight is not None - and (not self.has_bias or (self.has_bias and self.bias is not None)) - and (not self.has_weight_scale or (self.has_weight_scale and self.weight_scale is not None)) - and (not self.has_weight_zero_point or (self.has_weight_zero_point and self.weight_zero_point is not None)) - ) - - def ready_for_fused_merge(self) -> bool: - """ - 判断权重是否满足可以和其他权重进行融合cat的条件,因为可能权重是量化和非量化后的权重,所以复杂一些。 - """ - weight_ready = self.weight is not None and self.weight.dtype in [ - torch.bfloat16, - torch.float16, - torch.float32, - torch.float64, - ] - bias_ready = (self.has_bias and self.bias is not None) or (not self.has_bias) - if weight_ready and bias_ready: - return True - else: - return self.is_ready() - - def is_load_finished(self): - return ( - (self.is_ready() and 
self.weight.is_cuda) - and ((self.has_bias and self.bias.is_cuda) or (not self.has_bias)) - and ((self.has_weight_scale and self.weight_scale.is_cuda) or (not self.has_weight_scale)) - and ((self.has_weight_zero_point and self.weight_zero_point.is_cuda) or (not self.has_weight_zero_point)) - ) - - class MMWeightTpl(BaseWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], bias_names: Optional[Union[str, List[str]]], data_type: torch.dtype, @@ -70,7 +29,11 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__(tp_rank, tp_world_size, data_type) - self.lock = threading.Lock() + + self.in_dim = in_dim + if isinstance(out_dims, int): + out_dims = [out_dims] + self.out_dims = out_dims if isinstance(weight_names, str): weight_names = [weight_names] @@ -82,173 +45,76 @@ def __init__( if bias_names[0] is None: bias_names = None - if quant_method is not None: - has_weight_scale = quant_method.has_weight_scale - has_weight_zero_point = quant_method.has_weight_zero_point - else: - has_weight_scale = False - has_weight_zero_point = False - # 同时存在 weight_names 和 quanted_weight_names 是为了兼容在线和离线两种加载方案 self.weight_names = weight_names - self.bias_names = bias_names - has_bias = self.bias_names is not None - - self.gen_weight_quant_param_names(quant_method=quant_method) - self.quant_method = quant_method - self.sub_child_mm_params: List[MMWeightPack] = [ - MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - has_weight_zero_point=has_weight_zero_point, - ) - for _ in range(len(weight_names)) - ] - self.mm_param: MMWeightPack = MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - has_weight_zero_point=has_weight_zero_point, - ) + self.quant_method: QuantizationMethod = NoQuantization() if quant_method is None else quant_method self.param_slicer: SliceMixinTpl = None - - self.weight_fused_dim = 0 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 0 - - self.load_finished: bool = False + self._create_weight() + self.gen_weight_quant_param_names() def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: - if self.quant_method is not None: - return self.quant_method.apply( - input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger - ) - if out is None: - shape = (input_tensor.shape[0], self.mm_param.weight.shape[1]) - dtype = input_tensor.dtype - device = input_tensor.device - if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device) - else: - out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: - return torch.mm(input_tensor, self.mm_param.weight, out=out) - return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out) - - def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod]): - if quant_method is None: - self.quanted_weight_names = None - self.weight_zero_point_names = None - self.weight_scale_names = None - return - - quanted_weight_names = [] - weight_scale_names = [] - weight_zero_point_names = [] - - for weight_name in self.weight_names: - if quant_method.weight_scale_suffix is not None: - weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix) - weight_scale_names.append(weight_scale_name) - if quant_method.weight_zero_point_suffix is not None: - weight_zero_point_name = 
weight_name.replace("weight", quant_method.weight_zero_point_suffix) - weight_zero_point_names.append(weight_zero_point_name) - if quant_method.weight_suffix is not None: - weight_name = weight_name.replace("weight", quant_method.weight_suffix) - quanted_weight_names.append(weight_name) - - if len(quanted_weight_names) != 0: - self.quanted_weight_names = quanted_weight_names - else: - self.quanted_weight_names = None - - if len(weight_scale_names) != 0: - self.weight_scale_names = weight_scale_names - else: - self.weight_scale_names = None + return self.quant_method.apply( + input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias + ) - if len(weight_zero_point_names) != 0: - self.weight_zero_point_names = weight_zero_point_names - else: - self.weight_zero_point_names = None + def gen_weight_quant_param_names(self): + self.quanted_weight_names = [None] * len(self.weight_names) + self.weight_zero_point_names = [None] * len(self.weight_names) + self.weight_scale_names = [None] * len(self.weight_names) + + for sub_child_index, weight_name in enumerate(self.weight_names): + if self.quant_method.weight_scale_suffix is not None: + weight_scale_name = weight_name.replace("weight", self.quant_method.weight_scale_suffix) + self.weight_scale_names[sub_child_index] = weight_scale_name + if self.quant_method.weight_zero_point_suffix is not None: + weight_zero_point_name = weight_name.replace("weight", self.quant_method.weight_zero_point_suffix) + self.weight_zero_point_names[sub_child_index] = weight_zero_point_name + if self.quant_method.weight_suffix is not None: + weight_name = weight_name.replace("weight", self.quant_method.weight_suffix) + self.quanted_weight_names[sub_child_index] = weight_name return def load_hf_weights(self, weights): - if self.mm_param.is_load_finished(): - return for sub_child_index, param_name in enumerate(self.weight_names): self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - - if self.quanted_weight_names is not None: - for sub_child_index, param_name in enumerate(self.quanted_weight_names): - self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - + for sub_child_index, param_name in enumerate(self.weight_scale_names): + self._load_weight_scale(param_name=param_name, weights=weights, sub_child_index=sub_child_index) + for sub_child_index, param_name in enumerate(self.weight_zero_point_names): + self._load_weight_zero_point(param_name=param_name, weights=weights, sub_child_index=sub_child_index) if self.bias_names is not None: for sub_child_index, param_name in enumerate(self.bias_names): self._load_bias(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - if self.weight_scale_names is not None: - for sub_child_index, param_name in enumerate(self.weight_scale_names): - self._load_weight_scale(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - if self.weight_zero_point_names is not None: - for sub_child_index, param_name in enumerate(self.weight_zero_point_names): - self._load_weight_zero_point(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - - with self.lock: - # 如果需要fused的请求,全部ok了以后进行merge操作。, all([]) 竟然返回是True, 需要len(self.sub_child_mm_params) > 0 的额外判断。 - if len(self.sub_child_mm_params) > 0 and all(e.ready_for_fused_merge() for e in self.sub_child_mm_params): - self._fuse_weights() - self.sub_child_mm_params.clear() - - # 在线量化操作 - if ( - self.quant_method is not None - and 
self.mm_param.weight is not None - and self.quant_method.weight_need_quanted(self.mm_param.weight) - and self.load_finished is False - ): - logger.info(f"online quant weight names: {self.weight_names}") - quantized_weight, weight_scale, weight_zero_point = self.quant_method.quantize( - self.mm_param.weight.cuda(get_current_device_id()) - ) - self.mm_param.weight = quantized_weight - self.mm_param.weight_scale = weight_scale - self.mm_param.weight_zero_point = weight_zero_point - # repack 操作 - if ( - self.quant_method is not None - and self.mm_param.is_ready() - and self.quant_method.params_need_repack() - and self.load_finished is False - ): - ( - self.mm_param.weight, - self.mm_param.weight_scale, - self.mm_param.weight_zero_point, - ) = self.quant_method.params_repack( - weight=self.mm_param.weight, - weight_scale=self.mm_param.weight_scale, - weight_zero_point=self.mm_param.weight_zero_point, - dtype_type=self.data_type_, - ) - - if self.mm_param.is_ready() and self.load_finished is False: - self._to_gpu_device() - self.load_finished = True - - def verify_load(self) -> bool: - return self.mm_param.is_ready() + def _create_weight(self): + self.bias = None + if self.bias_names is not None: + self.bias = torch.empty(sum(self.out_dims), dtype=self.data_type_).cuda(get_current_device_id()) + # bias_list shares storage with bias for each output shard + self.bias_list = torch.split(self.bias, self.out_dims, dim=0) + for sub_bias in self.bias_list: + sub_bias.load_ok = False + self.mm_param: WeightPack = None + self.mm_param_list: List[WeightPack] = None + self.mm_param, self.mm_param_list = self.quant_method.create_weight( + in_dim=self.in_dim, out_dims=self.out_dims, dtype=self.data_type_, device_id=get_current_device_id() + ) + return # 执行顺序 def _load_weight( self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int ) -> None: + quanted_param_name = self.quanted_weight_names[sub_child_index] + # if the original weight is quantized, use the quantized_param_name. 
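
The new `_create_weight` above pre-allocates one fused bias and splits it into per-shard views; as the inline comment notes, `bias_list` shares storage with `bias`. A toy sketch of why copying a loaded shard into one view is enough (the shapes below are made up for illustration):

```python
import torch

# torch.split returns views that share storage with the parent tensor, so
# copy_()-ing a checkpoint shard into one element of bias_list also fills the
# matching region of the fused bias.
out_dims = [8, 4, 4]  # e.g. fused q/k/v output dims after the TP split (illustrative)
bias = torch.zeros(sum(out_dims), dtype=torch.float16)
bias_list = torch.split(bias, out_dims, dim=0)

k_bias_from_ckpt = torch.ones(4, dtype=torch.float16)
bias_list[1].copy_(k_bias_from_ckpt)  # writes through to `bias`

assert bias[8:12].eq(1).all()
assert bias[:8].eq(0).all()
```

The same property lets each checkpoint shard be filled independently during `load_hf_weights`, while `verify_load` only has to check the per-view `load_ok` flags.
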
+ if quanted_param_name in weights: + param_name = quanted_param_name if param_name in weights: weight = self.param_slicer._slice_weight(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight = weight + self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index]) return def _load_bias( @@ -256,7 +122,8 @@ def _load_bias( ) -> None: if param_name in weights: bias = self.param_slicer._slice_bias(weights[param_name]) - self.sub_child_mm_params[sub_child_index].bias = bias + self.bias_list[sub_child_index].copy_(bias) + self.bias_list[sub_child_index].load_ok = True return def _load_weight_scale( @@ -264,7 +131,7 @@ def _load_weight_scale( ) -> None: if param_name in weights: weight_scale = self.param_slicer._slice_weight_scale(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_scale = weight_scale + self.quant_method.load_weight_scale(weight_scale, self.mm_param_list[sub_child_index]) return def _load_weight_zero_point( @@ -272,102 +139,73 @@ def _load_weight_zero_point( ) -> None: if param_name in weights: weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_zero_point = weight_zero_point + self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param_list[sub_child_index]) return - # weight merge - def _fuse_weights(self) -> None: - need_merge = len(self.sub_child_mm_params) > 1 - if self.mm_param.weight is None and all(p.weight is not None for p in self.sub_child_mm_params): - if need_merge: - weight = torch.cat([p.weight for p in self.sub_child_mm_params], dim=self.weight_fused_dim) - else: - weight = self.sub_child_mm_params[0].weight - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight = None - - self.mm_param.weight = weight - - if ( - self.mm_param.has_bias - and self.mm_param.bias is None - and all(p.bias is not None for p in self.sub_child_mm_params) - ): - if need_merge: - bias = torch.cat([p.bias for p in self.sub_child_mm_params], dim=self.bias_fused_dim) - else: - bias = self.sub_child_mm_params[0].bias + def verify_load(self): + mm_param_load_ok = all(all(_mm_param.load_ok) for _mm_param in self.mm_param_list) + bias_load_ok = True if self.bias is None else all(sub_bias.load_ok for sub_bias in self.bias_list) + if not (mm_param_load_ok and bias_load_ok): + logger.warning(f"mm_param_load_ok: {self.mm_param_list[0].load_ok}") + return mm_param_load_ok and bias_load_ok - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.bias = None - - self.mm_param.bias = bias - - if self.mm_param.weight_scale is None and all(p.weight_scale is not None for p in self.sub_child_mm_params): - if need_merge: - weight_scale = torch.cat( - [p.weight_scale for p in self.sub_child_mm_params], dim=self.weight_scale_and_zero_point_fused_dim - ) - else: - weight_scale = self.sub_child_mm_params[0].weight_scale + def _get_tp_dim(self, dim: int) -> int: + assert ( + dim % self.tp_world_size_ == 0 + ), f"dim must be divisible by tp_world_size_, but found: {dim} % {self.tp_world_size_}" + return dim // self.tp_world_size_ - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_scale = None - self.mm_param.weight_scale = weight_scale - - if self.mm_param.weight_zero_point is None and all( - p.weight_zero_point is not None for p in self.sub_child_mm_params - ): - if need_merge: - weight_zero_point = torch.cat( - [p.weight_zero_point for p in self.sub_child_mm_params], - dim=self.weight_scale_and_zero_point_fused_dim, - ) - else: - 
weight_zero_point = self.sub_child_mm_params[0].weight_zero_point - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_zero_point = None - - self.mm_param.weight_zero_point = weight_zero_point +class BMMWeightTpl(BaseWeightTpl): + def __init__( + self, + dim0: int, + dim1: int, + dim2: int, + weight_names: Union[str, List[str]], + data_type: torch.dtype, + bias_names: Optional[Union[str, List[str]]] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(tp_rank, tp_world_size, data_type) + if isinstance(weight_names, str): + weight_names = [weight_names] + self.weight_names = weight_names + self.bias_names = bias_names + assert bias_names is None, "bmm not support bias" + if isinstance(bias_names, list): + assert all(bias_name is None for bias_name in bias_names), "bmm not support bias" + assert quant_method is None, "bmm not support quantized weight" + self.quant_method = quant_method + self.dim0 = dim0 + self.dim1 = dim1 + self.dim2 = dim2 + self._create_weight() return - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.mm_param.weight = ( - self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) - ) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) + def _create_weight(self): + self.weight = torch.empty(self.dim0, self.dim1, self.dim2, dtype=self.data_type_).cuda(get_current_device_id()) + self.weight.load_ok = False return + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + for weight_name in self.weight_names: + if weight_name in weights: + weight = self.param_slicer._slice_weight(weights[weight_name]) + self.weight.copy_(weight) + self.weight.load_ok = True + return -class BMMWeightTpl(MMWeightTpl): - def mm( - self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True - ) -> torch.Tensor: - raise RuntimeError("use bmm not mm") + def verify_load(self): + return self.weight.load_ok def bmm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: # 目前 bmm 不支持量化运算操作 - fpweight = self.mm_param.weight + fpweight = self.weight if out is None: shape = (input_tensor.shape[0], input_tensor.shape[1], fpweight.shape[2]) dtype = input_tensor.dtype @@ -376,90 +214,4 @@ def bmm( out = g_cache_manager.alloc_tensor(shape, dtype, device=device) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: - return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.mm_param.bias, input_tensor, fpweight, out=out) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # bmm 不需要 transpose 操作 - self.mm_param.weight = 
self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class DeepGemmFP8W8A8B128MMWeight(MMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()).transpose(0, 1) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - assert self.mm_param.has_weight_zero_point is False - - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class AWQMMWeightTpl(MMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - bias_names: Optional[Union[str, List[str]]] = None, - data_type: torch.dtype = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.weight_fused_dim = 1 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 1 - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return + return torch.bmm(input_tensor, fpweight, out=out) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 0eebdc74d2..30a699bb68 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,20 +1,18 @@ import torch -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, - BMMWeightTpl, -) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl, BMMWeightTpl from 
lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size +from .mm_slicer import get_row_slice_mixin -class StandardROWMMWeight(MMWeightTpl): +class ROWMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -22,42 +20,30 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + out_dims = [self._get_tp_dim(out_dim) for out_dim in out_dims] super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, bias_names=bias_names, data_type=data_type, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128ROWMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) - self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - return -class UnquantizedROWBMMWeight(BMMWeightTpl): +class KVROWNMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + kv_head_num: int, + head_dim: int, weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -65,42 +51,53 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.repeat_times = 1 + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQROWMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * 
head_dim + out_dims = [kv_hidden_size, kv_hidden_size] super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, data_type=data_type, bias_names=bias_names, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.repeat_times, ) - self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) + def _get_tp_padded_head_num(self, head_num: int): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + elif self.tp_world_size_ % head_num == 0: + self.repeat_times = self.tp_world_size_ // head_num + return self.repeat_times * head_num // self.tp_world_size_ + else: + raise ValueError( + f"head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by head_num, " + f"but found: {head_num} % {self.tp_world_size_}" + ) -class AWQMARLINROWMMWeight(AWQROWMMWeight): +class ROWBMMWeight(BMMWeightTpl): def __init__( self, + dim0: int, + dim1: int, + dim2: int, weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -108,18 +105,23 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + assert ( + dim0 % self.tp_world_size_ == 0 + ), f"dim0 of bmm must be divisible by tp_world_size_, but found: {dim0} % {self.tp_world_size_}" + dim0 = dim0 // self.tp_world_size_ super().__init__( + dim0=dim0, + dim1=dim1, + dim2=dim2, weight_names=weight_names, - data_type=data_type, bias_names=bias_names, + data_type=data_type, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.param_slicer = get_row_slice_mixin( + quant_method_name="none", tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) - - -ROWMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128ROWMMWeight, - "awq": AWQROWMMWeight, - "awq_marlin": AWQMARLINROWMMWeight, -} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 5a595bff61..c922bffc45 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,143 +1,241 @@ import torch -from typing import Optional +from typing import Optional, Dict from .base_weight import BaseWeightTpl -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward -from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward -from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_forward +from .platform_op import PlatformAwareOp -logger = init_logger(__name__) - -class 
_NormWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__() +class RMSNormWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype): + super().__init__(tp_rank=0, tp_world_size=1) + self.dim = dim self.weight_name = weight_name - self.bias_name = bias_name self.data_type_ = data_type - self.weight: torch.Tensor = None - self.bias: Optional[torch.Tensor] = None + self._create_weight() + + def _create_weight(self): + self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.weight.load_ok = True def verify_load(self): - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. - if self.bias_name is not None: - load_ok = load_ok and self.bias is not None - return load_ok - - def rmsnorm_forward( + return self.weight.load_ok + + def _native_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - assert input.ndim in [2, 3] and self.weight.ndim == 1 - assert self.bias is None + assert input.ndim == 2 and self.weight.ndim == 1 + assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" + x = input.to(torch.float32) + x_var = x + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * self.weight).to(self.data_type_) + if out is not None: + out.copy_(x) + return out + return x + + def _triton_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert ( + input.ndim in [2, 3] and self.weight.ndim == 1 + ), f"input.ndim: {input.ndim} != 2 or weight.ndim: {self.weight.ndim} != 1" if out is None: out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) - def layernorm_forward( + def _cuda_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - assert input.ndim == 2 and self.weight.ndim == 1 - assert self.bias is not None - - _tout = layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps) - if out is None: - return _tout - else: - out.copy_(_tout) - return out + # only triton implementation is supported for rmsnorm on cuda platform + return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def _musa_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # triton implementation is supported by musa. 
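
The `_native_forward` fallback above spells out the RMSNorm reference math: accumulate the variance in float32, scale by the reciprocal RMS, multiply by the weight, and cast back. A quick self-contained sanity check of that formula, separate from the patch and using random illustrative tensors:

```python
import torch

# Reference RMSNorm matching the float32-accumulation formula of _native_forward.
def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    xf = x.to(torch.float32)
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    return (xf * torch.rsqrt(variance + eps) * weight).to(x.dtype)

x = torch.randn(4, 16, dtype=torch.float16)
w = torch.ones(16, dtype=torch.float16)
y = rmsnorm_ref(x, w)
# With a unit weight, every row should come out with RMS close to 1.
print(y.float().pow(2).mean(dim=-1).sqrt())
```

The Triton and platform-specific paths are expected to agree with this reference within normal float16 tolerance.
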
+ return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) -class NoTpNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) - self.tp_world_size_ = 1 - self.tp_rank_ = 0 + def __call__( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) - def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights and self.bias is None: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) +class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__(tp_rank=0, tp_world_size=1) + self.dim = dim + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self._create_weight() + + def _create_weight(self): + self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + self.bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.weight.load_ok = True + if self.bias_name in weights: + self.bias.copy_(weights[self.bias_name]) + self.bias.load_ok = True -class NoTpGEMMANormWeight(_NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name, data_type, bias_name) - assert self.bias_name is None - self.tp_world_size_ = 1 - self.tp_rank_ = 0 + def verify_load(self): + return self.weight.load_ok and self.bias.load_ok - def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id()) + def _native_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" + x = torch.nn.functional.layer_norm( + input, normalized_shape=[self.dim], weight=self.weight, bias=self.bias, eps=eps + ) + if out is not None: + out.copy_(x.to(self.data_type_)) + return out + return x.to(self.data_type_) + def _triton_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # assert input.ndim == 2 and self.weight.ndim == 1 + if out is None: + return layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps) + else: + out.copy_(layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps)) + return out -class TpVitPadNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, head_num: int, bias_name=None): - super().__init__(weight_name, data_type, bias_name) - self.head_num = head_num + def _cuda_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # only triton 
implementation is supported for layernorm on cuda platform + return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) - def _pad_tensor_param(self, weight: torch.Tensor): - assert weight.ndim == 1 - hidden_size = weight.shape[0] - head_dim = hidden_size // self.head_num - assert hidden_size % self.head_num == 0 + def _musa_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # triton implementation is supported by musa. + return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) - if self.head_num % self.tp_world_size_ == 0: - return weight + def __call__( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + + +class TpRMSNormWeight(RMSNormWeight): + def __init__(self, head_num, head_dim, weight_name: str, data_type: torch.dtype): + padded_head_num = self._get_tp_padded_head_num(head_num) + dim = padded_head_num * head_dim + super().__init__(dim=dim, weight_name=weight_name, data_type=data_type) + # 重新初始化 tp rank 的信息, load hf weights 的时候会用到 + self.tp_rank_ = get_current_rank_in_dp() + self.tp_world_size_ = get_dp_world_size() + self.repeat_times_ = 1 + + def _get_tp_padded_head_num(self, head_num: int): + """ + Get the padded dimension for the weight. + 1. If head_num is divisible by tp_world_size_, return head_num. + 2. If head_num is greater than tp_world_size_, return: + (head_num + tp_world_size_ - 1) // tp_world_size_ * tp_world_size_ + 3. If head_num is less than tp_world_size_, assert tp_world_size_ is + divisible by head_num, and return head_num. + """ + self.tp_world_size_ = get_dp_world_size() + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + + if head_num > self.tp_world_size_: + return (head_num + self.tp_world_size_ - 1) // self.tp_world_size_ * self.tp_world_size_ else: - logger.warning(f"padding {self.weight_name} weights in TpVitPadNormWeight") - pad_head_num = self.tp_world_size_ - (self.head_num % self.tp_world_size_) - pad_dims = pad_head_num * head_dim - weight = torch.nn.functional.pad(weight, (0, pad_dims), mode="constant", value=0.0) - return weight + assert ( + self.tp_world_size_ % head_num == 0 + ), f"tp_world_size_ must be divisible by head_num, but found: {self.tp_world_size_} % {head_num}" + self.repeat_times_ = self.tp_world_size_ // head_num + return head_num * self.repeat_times_ // self.tp_world_size_ def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: + if self.weight_name in weights: t_weight = weights[self.weight_name] - t_weight = self._pad_tensor_param(t_weight) - new_hidden_size = t_weight.shape[0] - split_n_embed = new_hidden_size // self.tp_world_size_ - assert new_hidden_size % self.tp_world_size_ == 0 + hidden_size = t_weight.shape[0] + split_hidden_size = hidden_size // self.tp_world_size_ - start = split_n_embed * self.tp_rank_ - end = split_n_embed * (self.tp_rank_ + 1) + start = split_hidden_size * self.tp_rank_ // self.repeat_times_ + end = min(split_hidden_size * (self.tp_rank_ + 1) // self.repeat_times_, hidden_size) - self.weight = t_weight[start:end].to(self.data_type_).cuda(get_current_device_id()) + self.weight[: end - start].copy_(t_weight[start:end].to(self.data_type_)) + # the padding part is zero + self.weight[end - start :].zero_() + self.weight.load_ok = True - if self.bias_name in 
weights and self.bias is None: - t_bias = weights[self.bias_name] - t_bias = self._pad_tensor_param(t_bias) - new_hidden_size = t_bias.shape[0] - split_n_embed = new_hidden_size // self.tp_world_size_ - assert new_hidden_size % self.tp_world_size_ == 0 - start = split_n_embed * self.tp_rank_ - end = split_n_embed * (self.tp_rank_ + 1) +class NoTpGEMMANormWeight(RMSNormWeight): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype): + super().__init__(dim=dim, weight_name=weight_name, data_type=data_type) - self.bias = t_bias[start:end].to(self.data_type_).cuda(get_current_device_id()) + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.weight += 1 -class TpHeadNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name, data_type, bias_name) +class QKRMSNORMWeight(RMSNormWeight): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype): + super().__init__(dim=dim, weight_name=weight_name, data_type=data_type) - def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - t_weight = weights[self.weight_name] - start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) - self.weight: torch.Tensor = ( - t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) - ) - assert self.weight.ndim == 2 - - if self.bias_name in weights and self.bias is None: - t_bias = weights[self.bias_name] - start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_bias) - self.bias: torch.Tensor = ( - t_bias[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) - ) - assert self.bias.ndim == 2 + def _native_forward( + self, + input: torch.Tensor, + eps: float, + ) -> None: + assert input.ndim == 2 and self.weight.ndim == 1 + assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" + head_dim = self.weight.shape[0] + x = input.to(torch.float32) + x = x.view(-1, head_dim) + x_var = x + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * self.weight).to(self.data_type_) + x = x.view(-1, input.shape[-1]) + input.copy_(x) + return + + def _triton_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + return qk_rmsnorm_forward(x=input, weight=self.weight, eps=eps) + + def _cuda_forward( + self, + input: torch.Tensor, + eps: float, + ) -> None: + self._triton_forward(input=input, eps=eps) + return + + def _musa_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + # musa implementation is supported by musa triton on musa platform + return self._triton_forward(input=input, eps=eps) + + def __call__( + self, + input: torch.Tensor, + eps: float, + ) -> None: + return self._forward(input=input, eps=eps) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py b/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py new file mode 100644 index 0000000000..1ba1610fc9 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py @@ -0,0 +1,71 @@ +import torch +from abc import ABC, abstractmethod +from typing import Optional, Callable, Any +from lightllm.utils.device_utils import get_platform, Platform +from lightllm.utils.envs_utils import get_env_start_args +from 
lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class PlatformAwareOp(ABC): + """ + platform aware op base class, + automatically route to the corresponding implementation method according to the platform. + """ + + def __init__(self): + args = get_env_start_args() + self.platform = get_platform(args.hardware_platform) + self.enable_torch_fallback = args.enable_torch_fallback + self.enable_triton_fallback = args.enable_triton_fallback + self._forward = self._route_forward() + + def _route_forward(self) -> Callable: + + method_name_map = { + Platform.CUDA: "_cuda_forward", + Platform.ASCEND: "_ascend_forward", + Platform.CAMBRICON: "_cambricon_forward", + Platform.MUSA: "_musa_forward", + Platform.ROCM: "_rocm_forward", + Platform.CPU: "_cpu_forward", + } + + method_name = method_name_map.get(self.platform) + if method_name and hasattr(self, method_name): + method = getattr(self, method_name) + if callable(method): + return method + + if self.enable_triton_fallback: + if hasattr(self, "_triton_forward"): + return self._triton_forward + logger.warning( + f"No triton implementation found for {self.__class__.__name__} on {self.platform.name} platform. " + f"Please implement {self.__class__.__name__}_{self.platform.name}_triton_forward method, " + f"or set --enable_torch_fallback to use default implementation." + ) + + if self.enable_torch_fallback: + return self._native_forward + + # if no implementation found, raise error + raise NotImplementedError( + f"No implementation found for {self.__class__.__name__} on {self.platform.name} platform. " + f"Please implement {self.__class__.__name__}_{self.platform.name}_forward method, " + f"or set --enable_torch_fallback to use default implementation." + ) + + @abstractmethod + def _native_forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("default forward must implement this method") + + @abstractmethod + def _cuda_forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("cuda forward must implement this method") + + # Since Triton may be compatible with all hardware platforms in the future, + # so provide triton implementation as a fallback for all hardware platforms + def _triton_forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("triton forward must implement this method") diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 4bc58c76f6..86a887a259 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -4,6 +4,7 @@ from .base_layer_weight import BaseLayerWeight from .meta_weights import BaseWeight, MMWeightTpl from lightllm.utils.log_utils import init_logger +from lightllm.common.quantization import Quantcfg logger = init_logger(__name__) @@ -14,7 +15,7 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg): self.layer_num_ = layer_num self.data_type_ = data_type self.network_config_ = network_config - self.quant_cfg = quant_cfg + self.quant_cfg: Quantcfg = quant_cfg self._parse_config() self._init_weight_names() self._init_weight() @@ -40,3 +41,6 @@ def load_hf_weights(self, weights): attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + + def get_quant_method(self, name): + return self.quant_cfg.get_quant_method(self.layer_num_, name) diff --git a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py 
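The `PlatformAwareOp` base class added above routes each op to a platform-specific `_*_forward`, falling back to `_triton_forward` and then `_native_forward` when the corresponding start-args flags allow it. Below is a minimal self-contained sketch of that dispatch pattern; the `Platform` enum, the `DemoSilu` op, and the fallback flags here are simplified stand-ins, not the LightLLM implementations:

```python
# Simplified sketch of the platform-routing idea in PlatformAwareOp.
import enum
import torch


class Platform(enum.Enum):
    CUDA = "cuda"
    MUSA = "musa"
    CPU = "cpu"


class DemoPlatformAwareOp:
    def __init__(self, platform: Platform, enable_triton_fallback: bool = True, enable_torch_fallback: bool = True):
        self.platform = platform
        # Try the platform-specific implementation first, then the fallbacks.
        method = getattr(self, f"_{platform.value}_forward", None)
        if method is None and enable_triton_fallback:
            method = getattr(self, "_triton_forward", None)
        if method is None and enable_torch_fallback:
            method = getattr(self, "_native_forward", None)
        if method is None:
            raise NotImplementedError(f"no forward implementation for {platform}")
        self._forward = method

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward(x)


class DemoSilu(DemoPlatformAwareOp):
    def _native_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-torch fallback, always available.
        return torch.nn.functional.silu(x)

    def _cuda_forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would launch a CUDA/Triton kernel here.
        return torch.nn.functional.silu(x)


op = DemoSilu(Platform.CPU)  # no _cpu_forward or _triton_forward -> torch fallback
print(op(torch.randn(4)).shape)  # torch.Size([4])
```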
b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py deleted file mode 100644 index 143d93b230..0000000000 --- a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py +++ /dev/null @@ -1,649 +0,0 @@ -import time - -import torch - -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - ], - key=['M', 'N', 'K', 'NO_GROUPS'], -) -@triton.jit -def matmul4_kernel( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, NO_GROUPS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N//8) int32 - groupsize is an int specifying the size of groups for scales and zeros. - G is K // groupsize. - Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient. - WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. - WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. - WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. 
- """ - bits = 4 - infearure_per_bits = 8 - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) - # zeros_ptrs is set up such that it repeats elements along the N axis 8 times - zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,) - # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - # If G == 1, scales and zeros are the same for all K, so we can load them once - if NO_GROUPS: - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 - # Unpack zeros - zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 - # zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16 - zeros = zeros * scales - # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) - # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension - # So this loop is along the infeatures dimension (K) - # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - if not NO_GROUPS: - g_id = k // (groupsize // BLOCK_SIZE_K) - ptr = scales_ptrs + g_id * stride_scales_g - scales = tl.load(ptr) # (BLOCK_SIZE_N,) - ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,) - zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 - # Unpack zeros - zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 - zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16 - # Now we need to unpack b (which is 4-bit values) into 32-bit values - b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values - b = b * scales[None, :] - zeros[None, :] # Scale and shift - # print("data type", a, b) - accumulator += tl.dot(a, b.to(a.dtype)) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - c = accumulator.to(c_ptr.dtype.element_ty) - # Store the result - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size, output=None) -> torch.FloatTensor: - """ - Compute the matrix multiplication C = A x B + bias. - Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. - - A is of shape (..., K) float16 - qweight is of shape (K//8, N) int32 - scales is of shape (G, N) float16 - qzeros is of shape (G, N//8) int32 - bias is of shape (1, N) float16 - - groupsize is the number of infeatures in each group. 
- G = K // groupsize - - Returns C of shape (..., N) float16 - """ - assert x.shape[-1] == (qweight.shape[0] * 8), "A must be a multiple of 8 in the last dimension" - assert x.is_contiguous(), "A must be contiguous" - - M, K = x.shape - N = qweight.shape[1] - # This is based on the possible BLOCK_SIZE_Ks - # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" - # # This is based on the possible BLOCK_SIZE_Ns - # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" - # # This is based on the possible BLOCK_SIZE_Ks - # assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" - - # output = torch.empty((M, N), device='cuda', dtype=torch.float16) - if output is None: - inplace = False - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - else: - inplace = True - - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul4_kernel[grid]( - x, qweight, output, - scales, qzeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), scales.stride(1), - qzeros.stride(0), qzeros.stride(1), - group_size, group_size == K, - ) - # return output - if not inplace: - return output - - -@triton.autotune( - configs=[ - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, 
num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - - ], - key=['M', 'N', 'K'], - reset_to_zero=['c_ptr'] -) -@triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - bs_ptr, bzp_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_bsk, stride_bsn, - stride_bzpk, stride_bzpn, - group_size, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr - ): - """ - assert K % (BLOCK_SIZE_K * SPLIT_K) == 0 - """ - pid = tl.program_id(axis=0) - pid_sp_k = tl.program_id(axis=1) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - - # [BLOCK_M, BLOCK_K] - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - # [BLOCK_K, BLOCK_N] but repeated 8 times in N - b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn - # tl.static_print("shape", a_ptrs, b_ptrs, bs_ptrs, bzp_ptrs) - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - # Load the next block of A and B. - # [BLOCK_K, BLOCK_N] but repeated group_size times in K - bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \ - + offs_bn[None, :] * stride_bsn - # [BLOCK_K, BLOCK_N] but repeated in K and N - bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \ - + (offs_bn[None, :] // 8) * stride_bzpn - b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0 - bzp_shift_bits = (offs_bn[None, :] % 8) * 4 - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - bs = tl.load(bs_ptrs) - bzp = tl.load(bzp_ptrs) - # We accumulate along the K dimension. - int_b = (b >> b_shift_bits) & 0xF - int_bzp = (bzp >> bzp_shift_bits) & 0xF - b = ((int_b - int_bzp) * bs).to(a.dtype) - accumulator += tl.dot(a, b.to(a.dtype)) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0 - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - c = accumulator.to(c_ptr.dtype.element_ty) - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. 
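For reference while reading the removed kernel: when `SPLIT_K > 1`, each program computes a partial product over one K slice and adds it into a zero-initialized output (`tl.atomic_add`, together with `reset_to_zero=['c_ptr']` in the autotune config). A plain-PyTorch sketch of that split-K accumulation, with illustrative shapes:

```python
# Split-K accumulation: each K slice contributes a partial product that is
# summed into a zero-initialized output, mirroring the atomic_add branch.
import torch

def splitk_matmul(a: torch.Tensor, b: torch.Tensor, split_k: int = 2) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    assert K % split_k == 0
    out = torch.zeros(M, N, dtype=torch.float32, device=a.device)
    step = K // split_k
    for s in range(split_k):
        ks = slice(s * step, (s + 1) * step)
        out += a[:, ks].float() @ b[ks, :].float()  # one "program" per K slice
    return out.to(a.dtype)

a = torch.randn(8, 64, dtype=torch.float16)
b = torch.randn(64, 16, dtype=torch.float16)
ref = (a.float() @ b.float()).to(torch.float16)
print(torch.allclose(splitk_matmul(a, b), ref, atol=1e-2, rtol=1e-2))
```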
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) - - -def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: - """ - """ - assert x.is_contiguous(), "A must be contiguous" - assert qweight.is_contiguous(), "B must be contiguous" - M, K = x.shape - N = scales.shape[1] - if output is None: - output = torch.zeros((M, N), device=x.device, dtype=x.dtype) - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'], - ) - matmul_kernel[grid]( - x, qweight, output, - scales, qzeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), scales.stride(1), - qzeros.stride(0), qzeros.stride(1), - group_size, - ) - return output - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - ], - key=['K', 'N'], -) -@triton.jit -def dequantize_kernel( - # Pointers to matrices - b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr, - # Matrix dimensions - K, N, group_size, - stride_bk, stride_bn, - stride_bsk, stride_bsn, - stride_bzpk, stride_bzpn, - stride_fpbk, stride_fpbn, - # Meta-parameters - BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, -): - """Dequantize tile [BLOCK_SIZE_K, BLOCK_SIZE_N] in full precision. - We should assert BLOCK_SIZE_N % 8 == 0. 
- weight[K // 8, N], scale[K // group_size, N], zp[K // group_size, N // group_size] - """ - k_block_idx = tl.program_id(axis=0) - n_block_idx = tl.program_id(axis=1) - offs_k = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - offs_n = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - fpb_offs = offs_k[:, None] * stride_fpbk + offs_n[None, :] * stride_fpbn - b_offs = (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn - bzp_offs = (offs_k[:, None] // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn - bs_offs = (offs_k[:, None] // group_size) * stride_bsk + offs_n[None, :] * stride_bsn - n_mask = offs_n[None, :] < N - k_mask = offs_k[:, None] < K - mask = n_mask & k_mask - int32_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0) - zp_b = tl.load(b_zp_ptr + bzp_offs, mask=mask, other=0.0) - scale_b = tl.load(b_scale_ptr + bs_offs, mask=mask, other=0.0) - b_shift = (offs_k[:, None] % 8) * 4 - bzp_shift = (offs_n[None, :] % 8) * 4 - fp_weight = (((int32_b >> b_shift) & 0xF) - ((zp_b >> bzp_shift) & 0xF)) * scale_b - tl.store(fpb_ptr + fpb_offs, fp_weight, mask=mask) - - -def dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size): - Kw, N = b.shape - K = Kw * 8 - fp_b = torch.ones((K, N), device=device, dtype=dtype) - grid = lambda META: ( - triton.cdiv(K, META['BLOCK_SIZE_K']), - triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - dequantize_kernel[grid]( - b, b_scale, b_zero_point, fp_b, - K, N, group_size, - b.stride(0), b.stride(1), - b_scale.stride(0), b_scale.stride(1), - b_zero_point.stride(0), b_zero_point.stride(1), - fp_b.stride(0), fp_b.stride(1) - ) - return fp_b - - -def matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, group_size=128, out=None): - """ - Matmul dequantize int4 s1 dequantize weight to `fp_b` and do fp16 torch.mm, - this is for `prefill` stage, since weight size is fixed so is dequantize overhead, - perfill stage have more tokens to amortize dequant cost. - """ - assert a.is_contiguous(), "Matrix A must be contiguous" - # assert b.is_contiguous(), "Matrix B must be contiguous" - M, K = a.shape - Kw, N = b.shape - if out is None: - # Allocates output. - out = torch.empty((M, N), device=a.device, dtype=a.dtype) - fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, group_size) - torch.mm(a, fp_b, out=out) - fp_b = None - return out - - -def quantize_int4(weight, group_size=128, tp_rank=0): - # Weight shape: [H1 // 8, H2] - # Scale shape: [H1 // group_size, H2] - # zero_pint shape: [H1 // group_size, H2 // 8] - - weight = weight.transpose(1, 0) - h1, h2 = weight.shape - assert h1 % 8 == 0 and h2 % 8 == 0, "H1 {} H2 {}".format(h1, h2) - assert h2 % group_size == 0, "H1 {} H2 {}".format(h1, h2) - weight = weight.contiguous().view(-1, group_size).cuda(tp_rank) - weight_max = weight.amax(-1, keepdim=True) - weight_max = torch.where(weight_max < 0, 0, weight_max) - weight_min = weight.amin(-1, keepdim=True) - weight_min = torch.where(weight_min > 0, 0, weight_min) - weight_range = weight_max - weight_min - scale = weight_range / (2 ** 4 - 1) - zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32) - weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2) - int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device) - int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device) - zero_point = zero_point.view(h1, -1) - scale = scale.view(h1, -1) - # pack 8 int4 in an int32 number. - # Weight pack in row. 
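The packing loops just below store eight 4-bit values in one int32, nibble `i` occupying bits `[4*i, 4*i+4)`. A small round-trip sketch of that layout (helper names are illustrative; as in the removed code, the top nibble may land in the sign bit):

```python
# Nibble packing/unpacking: eight int4 values per int32 word.
import torch

def pack_int4_rows(w: torch.Tensor) -> torch.Tensor:
    # w: (H1, H2) int32 values in [0, 15]; returns (H1, H2 // 8) int32
    h1, h2 = w.shape
    packed = torch.zeros(h1, h2 // 8, dtype=torch.int32)
    for col in range(0, h2, 8):
        for i in range(8):
            packed[:, col // 8] |= w[:, col + i] << (i * 4)
    return packed

def unpack_int4_rows(packed: torch.Tensor) -> torch.Tensor:
    h1, cols = packed.shape
    w = torch.empty(h1, cols * 8, dtype=torch.int32)
    for col in range(cols):
        for i in range(8):
            w[:, col * 8 + i] = (packed[:, col] >> (i * 4)) & 0xF
    return w

x = torch.randint(0, 16, (4, 16), dtype=torch.int32)
assert torch.equal(unpack_int4_rows(pack_int4_rows(x)), x)
```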
- for pack in range(0, h2, 8): - for i in range(8): - int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4) - # zero point pack in col. - for pack in range(0, h1, 8): - for i in range(8): - int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4) - ''' - fp_weight = torch.zeros(h1, h2).half().to(weight.device) - for pack in range(0, h1 // 8): - for i in range(8): - fp_weight[pack * 8 + i, :] = \ - ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16 - print((fp_weight - weight).abs().sum()) - - fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device) - for pack in range(0, h1 // 8): - for i in range(8): - fp_zp[pack * 8 + i, :] = \ - (int_zero_point[pack, :] >> (i * 4)) & 15 - - print((fp_zp - zero_point).abs().sum()) - ''' - weight = None - return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size - - -def unpack_int4(weight, scale, zp): - """ - Test function to verify quantize int4 is correct. - Will not be used in model inference. - """ - weight = weight.transpose(1, 0) - scale = scale.transpose(1, 0) - zp = zp.transpose(1, 0) - h1, h2 = weight.shape - group_size = h2 * 8 // scale.shape[1] - group_num = scale.shape[1] - fp_weight = torch.zeros(h1, h2 * 8).half().to(weight.device) - fp_zero_point = torch.zeros(h1, group_num).to(weight.device) - for pack in range(0, h2): - for i in range(8): - fp_weight[:, pack * 8 + i] = (weight[:, pack] >> (i * 4)) & 0xF - for pack in range(0, h1 // 8): - for i in range(8): - fp_zero_point[pack * 8 + i, :] = (zp[pack, :] >> (i * 4)) & 0xF - for g in range(group_num): - fp_weight[:, g * group_size:(g + 1) * group_size] = (fp_weight[:, g * group_size:(g + 1) * group_size] - \ - fp_zero_point[:, g].unsqueeze(1)) * scale[:, g].unsqueeze(1) - return fp_weight.transpose(1, 0) - - -def test_int4(M, K, N): - import time - - print("M: {} K: {} N: {}".format(M, K, N)) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b) - for _ in range(10): - triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point) - torch.cuda.synchronize() - t2 = time.time() - triton_time = t2 - t1 - print("Triton time cost", (t2 - t1)) - for _ in range(10): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - t2 = time.time() - torch_time = t2 - t1 - print("Torch time cost", (t2 - t1)) - return triton_time, torch_time - - -def test_correct_int4_s1(M=32, K=4096, N=4096): - group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) - cos = torch.nn.CosineSimilarity(0) - fp_weight = dequantize_int4(int_b, b_scale, b_zero_point, a.device, a.dtype, group_size) - print("Quantize cos", cos(fp_weight.flatten().to(torch.float32), b.flatten().to(torch.float32))) - triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point, group_size) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos", 
cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -def test_correct_int4_s2(M=32, K=4096, N=4096): - group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) - cos = torch.nn.CosineSimilarity(0) - fp_weight = unpack_int4(int_b, b_scale, b_zero_point) - print("Quantize cos", cos(fp_weight.flatten().to(torch.float32), b.flatten().to(torch.float32))) - triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -def test_correct_int4_gptq(M=32, K=4096, N=4096): - group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) - cos = torch.nn.CosineSimilarity(0) - fp_weight = unpack_int4(int_b, b_scale, b_zero_point) - print("Quantize cos", cos(fp_weight.flatten().to(torch.float32), b.flatten().to(torch.float32))) - triton_output = matmul_dequantize_int4_gptq(a, int_b, b_scale, b_zero_point, group_size) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M'], # Argument names to use as an x-axis for the plot - x_vals=[4, 8, 16, 32, 64, 128] + [ - 128 * i for i in range(2, 33, 2) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=['cublas', 'triton-s1', 'dequantize', 'triton-s2', 'triton-gptq'], - # Label name for the lines - line_names=["cuBLAS", "Triton-s1", "Dequant(GB/s)", "Triton-s2", "Triton-gptq"], - # Line styles - styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-'), ('yellow', '-')], - ylabel="TFLOPS", # Label name for the y-axis - plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
- args={}, - ) -) -def benchmark(M, provider): - K = 4096 - N = 4096 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-s1': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_s1(a, intb, b_scale, bzp, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-s2': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_s2(a, intb, b_scale, bzp, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'dequantize': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: dequantize_int4(intb, b_scale, bzp, 'cuda', torch.float16, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * K * 1e-9 / (ms * 1e-3) - if provider == 'triton-gptq': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_gptq(a, intb, b_scale, bzp, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - return perf(ms), perf(max_ms), perf(min_ms) - - -def test_model_layer(bs, sqe_len, hidden, inter, tp): - st1 = 0 - st2 = 0 - t1, t2 = test_int4(bs * sqe_len, hidden, hidden * 3 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int4(bs * sqe_len, hidden // tp, hidden) - st1 += t1 - st2 += t2 - t1, t2 = test_int4(bs * sqe_len, hidden, inter * 2 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int4(bs * sqe_len, inter // tp, hidden) - st1 += t1 - st2 += t2 - print("Triton time {} Torch time {}".format(st1, st2)) - - -if __name__ == "__main__": - # test_correct_int4_s1() - # test_correct_int4_s2() - # test_correct_int4_gptq() - benchmark.run(show_plots=True, print_data=True) - exit() - bs = 32 - hidden = 4096 - inter = 11008 - prefill_len = 512 - decode_len = 1 - tp = 1 - test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) diff --git a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py deleted file mode 100644 index e2c5c0dc95..0000000000 --- a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 256}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, 
num_warps=2), - triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - ], - key=['K', 'N'], -) - - -@triton.jit -def dequantize_kernel( - # Pointers to matrices - b_ptr, b_scale_ptr, fpb_ptr, - # Matrix dimensions - K, N, - stride_bk, stride_bn, - stride_fpbk, stride_fpbn, - # Meta-parameters - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - k_block_idx = tl.program_id(axis=0) - n_block_idx = tl.program_id(axis=1) - offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_n = tl.arange(0, BLOCK_SIZE_N) - b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + \ - (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_bn - fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + \ - (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_fpbn - bs_offs = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] - n_mask = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] < N - mask = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None] < K) & n_mask - int_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0) - scale_b = tl.load(b_scale_ptr + bs_offs, mask=n_mask, other=0.0) - tl.store(fpb_ptr + fpb_offs, int_b * scale_b, mask=mask) - - -def matmul_dequantize_int8(a, b, b_scale, out=None): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - # assert b.is_contiguous(), "Matrix B must be contiguous" - M, K = a.shape - K, N = b.shape - if out == None: - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - else: - c = out - fp_b = torch.empty((K, N), device=a.device, dtype=a.dtype) - grid = lambda META: ( - triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - dequantize_kernel[grid]( - b, b_scale, fp_b, - K, N, - b.stride(0), b.stride(1), - fp_b.stride(0), fp_b.stride(1) - ) - torch.mm(a, fp_b, out=c) - return c - - -def quantize_int8(weight, axis=0, tp_rank=0): - # Weight shape: [H1, H2] - # Scale shape: [H2] - scale = weight.abs().amax(axis, keepdim=True) / 127. 
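For reference, the removed `quantize_int8` uses symmetric per-channel quantization: the scale is `max|w| / 127` along the reduction axis, and dequantization is `w_int8 * scale`. A short sketch of that round trip (with an explicit round/clamp added here for clarity):

```python
# Symmetric per-channel int8 quantization round trip.
import torch

def quantize_int8_ref(w: torch.Tensor, axis: int = 0):
    scale = w.abs().amax(axis, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale

w = torch.randn(64, 32, dtype=torch.float32)
q, scale = quantize_int8_ref(w)
w_hat = q.to(torch.float32) * scale  # dequantize
print((w - w_hat).abs().max())  # small quantization error
```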
- weight = (weight / scale).to(torch.int8) - if axis == 0: - weight = weight.t().contiguous().t() - scale = scale.squeeze(axis) - return weight.contiguous().cuda(tp_rank), scale.contiguous().cuda(tp_rank) - - -def test_int8(M, K, N): - import time - - print("M: {} K: {} N: {}".format(M, K, N)) - torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale = quantize_int8(b) - for _ in range(10): - triton_output = matmul_dequantize_int8(a, int_b, b_scale.unsqueeze(0)) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - triton_output = matmul_dequantize_int8(a, int_b, b_scale.unsqueeze(0)) - torch.cuda.synchronize() - t2 = time.time() - triton_time = t2 - t1 - print("Triton time cost", (t2 - t1)) - for _ in range(10): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - t2 = time.time() - torch_time = t2 - t1 - print("Torch time cost", (t2 - t1)) - return triton_time, torch_time - - -def test_correct_int8(M=512, K=4096, N=4096): - import time - - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale = quantize_int8(b) - cos = torch.nn.CosineSimilarity(0) - triton_output = matmul_dequantize_int8(a, int_b, b_scale) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos ", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot - x_vals=[32, 64, 128, 256] + [ - 512 * i for i in range(1, 33) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=['cublas', 'triton'], - # Label name for the lines - line_names=["cuBLAS", "Triton"], - # Line styles - styles=[('green', '-'), ('blue', '-')], - ylabel="TFLOPS", # Label name for the y-axis - plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
- args={}, - ) -) - - -def benchmark(M, N, K, provider): - quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - if provider == 'triton': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - intb, b_scale = quantize_int8(b) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int8(a, intb, b_scale), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - return perf(ms), perf(min_ms), perf(max_ms) - - -def test_model_layer(bs, sqe_len, hidden, inter, tp): - st1 = 0 - st2 = 0 - t1, t2 = test_int8(bs * sqe_len, hidden, hidden * 3 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int8(bs * sqe_len, hidden // tp, hidden) - st1 += t1 - st2 += t2 - t1, t2 = test_int8(bs * sqe_len, hidden, inter * 2 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int8(bs * sqe_len, inter // tp, hidden) - st1 += t1 - st2 += t2 - print("Triton time {} Torch time {}".format(st1, st2)) - - -if __name__ == "__main__": - test_correct_int8() - benchmark.run(show_plots=True, print_data=True) - - bs = 32 - hidden = 4096 - inter = 11008 - prefill_len = 512 - decode_len = 1 - tp = 1 - test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) \ No newline at end of file diff --git a/lightllm/common/fused_moe/__init__.py b/lightllm/common/basemodel/triton_kernel/fused_moe/__init__.py similarity index 100% rename from lightllm/common/fused_moe/__init__.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/__init__.py diff --git a/lightllm/common/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py similarity index 100% rename from lightllm/common/fused_moe/deepep_scatter_gather.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py similarity index 97% rename from lightllm/common/fused_moe/grouped_fused_moe.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 758d83ba34..97075e9123 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -28,7 +28,7 @@ from .moe_kernel_configs import MoeGroupedGemmKernelConfig from .moe_silu_and_mul import silu_and_mul_fwd from .moe_sum_reduce import moe_sum_reduce -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.utils.torch_ops_utils import direct_register_custom_op from lightllm.common.triton_utils.autotuner import autotune @@ -387,8 +387,8 @@ def grouped_matmul_kernel( k, # int n, # int topk_num, # int - token_scale_ptr, # [1,] for per tensor quant, or [token_num, hidden_dim // block_size] for per token, group quant - weight_scale_ptr, # [expert_num,] or [export_num, n // block_size_n, k // block_size_k] + token_scale_ptr, # [token_num,] for pertoken quant, or [token_num,hidden_dim//block_size] for per group quant + weight_scale_ptr, # [expert_num, n] or [export_num, n // block_size_n, k 
// block_size_k] weight_scale_stride0, weight_scale_stride1, weight_scale_stride2, @@ -497,8 +497,18 @@ def grouped_matmul_kernel( b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1 else: - a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last") - b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last") + # per token scale quant + if TOKEN_INPUT_USE_TMA: + assert MUL_ROUTED_WEIGHT is True + a_scale_ptrs = token_scale_ptr + (token_start_index + tl.arange(0, BLOCK_SIZE_M))[:, None] + else: + a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num)[:, None] + + a_scale = tl.load(a_scale_ptrs, eviction_policy="evict_last") + b_scale = tl.load( + weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bn[None, :] * weight_scale_stride1, + eviction_policy="evict_last", + ) ab_scale = a_scale * b_scale if NEED_TRANS: @@ -745,8 +755,12 @@ def grouped_matmul( if use_fp8_w8a8: # 当权重使用 block wise 量化时,激活也使用 per token, group size 量化 if block_size_k == 0: - token_inputs, token_input_scale = vllm_ops.scaled_fp8_quant(token_inputs, token_input_scale) + # input 使用 per token 量化 + token_inputs, token_input_scale = vllm_ops.scaled_fp8_quant( + token_inputs, token_input_scale, use_per_token_if_dynamic=True + ) else: + # input 使用 per group quant 量化 _m, _k = token_inputs.shape assert _k % block_size_k == 0 token_inputs, token_input_scale = per_token_group_quant_fp8( diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py similarity index 96% rename from lightllm/common/fused_moe/grouped_fused_moe_ep.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 5cc0d7a9be..2c6d013bd5 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -6,13 +6,15 @@ from typing import Any, Callable, Dict, Optional, Tuple import torch.distributed as dist from lightllm.utils.log_utils import init_logger -from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd -from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( + silu_and_mul_masked_post_quant_fwd, +) +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, tma_align_input_scale, ) -from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather +from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank from lightllm.common.triton_utils.autotuner import Autotuner import numpy as np diff --git a/lightllm/common/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py similarity index 100% rename from lightllm/common/fused_moe/grouped_topk.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py diff --git a/lightllm/common/fused_moe/moe_kernel_configs.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py similarity index 100% rename from 
lightllm/common/fused_moe/moe_kernel_configs.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py diff --git a/lightllm/common/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py similarity index 100% rename from lightllm/common/fused_moe/moe_silu_and_mul.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py diff --git a/lightllm/common/fused_moe/moe_silu_and_mul_config.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py similarity index 100% rename from lightllm/common/fused_moe/moe_silu_and_mul_config.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py diff --git a/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py similarity index 100% rename from lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py diff --git a/lightllm/common/fused_moe/moe_sum_recude_config.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py similarity index 100% rename from lightllm/common/fused_moe/moe_sum_recude_config.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py diff --git a/lightllm/common/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py similarity index 100% rename from lightllm/common/fused_moe/moe_sum_reduce.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py diff --git a/lightllm/common/fused_moe/softmax_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/softmax_topk.py similarity index 100% rename from lightllm/common/fused_moe/softmax_topk.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/softmax_topk.py diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py similarity index 96% rename from lightllm/common/fused_moe/topk_select.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 5206800efc..72c3a381ed 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -22,7 +22,7 @@ from lightllm.utils.sgl_utils import sgl_ops from lightllm.utils.light_utils import light_ops from typing import Callable, List, Optional, Tuple -from lightllm.common.fused_moe.softmax_topk import softmax_topk +from lightllm.common.basemodel.triton_kernel.fused_moe.softmax_topk import softmax_topk from lightllm.common.triton_utils.autotuner import Autotuner use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"] @@ -177,8 +177,8 @@ def select_experts( scoring_func: str = "softmax", custom_routing_function: Optional[Callable] = None, ): - from lightllm.common.fused_moe.topk_select import fused_topk - from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk + from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import fused_topk + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_topk import triton_grouped_topk # DeekSeekv2 uses grouped_top_k if use_grouped_topk: diff --git a/lightllm/common/quantization/triton_quant/__init__.py b/lightllm/common/basemodel/triton_kernel/norm/__init__.py similarity index 100% rename from 
lightllm/common/quantization/triton_quant/__init__.py rename to lightllm/common/basemodel/triton_kernel/norm/__init__.py diff --git a/lightllm/common/basemodel/triton_kernel/layernorm.py b/lightllm/common/basemodel/triton_kernel/norm/layernorm.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/layernorm.py rename to lightllm/common/basemodel/triton_kernel/norm/layernorm.py diff --git a/lightllm/models/qwen3/triton_kernel/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py similarity index 100% rename from lightllm/models/qwen3/triton_kernel/qk_norm.py rename to lightllm/common/basemodel/triton_kernel/norm/qk_norm.py diff --git a/lightllm/common/basemodel/triton_kernel/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/rmsnorm.py rename to lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py diff --git a/lightllm/common/quantization/triton_quant/fp8/__init__.py b/lightllm/common/basemodel/triton_kernel/quantization/__init__.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/__init__.py rename to lightllm/common/basemodel/triton_kernel/quantization/__init__.py diff --git a/lightllm/common/basemodel/triton_kernel/bmm_scaled_fp8.py b/lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/bmm_scaled_fp8.py rename to lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py diff --git a/lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py b/lightllm/common/basemodel/triton_kernel/quantization/q_per_head_fp8_quant.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py rename to lightllm/common/basemodel/triton_kernel/quantization/q_per_head_fp8_quant.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py similarity index 93% rename from lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py index 7c76e82c9e..f14e8b2833 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py +++ 
b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py @@ -11,8 +11,8 @@ from lightllm.utils.device_utils import triton_support_tensor_descriptor, is_5090_gpu -class Fp8ScaledMMKernelConfig(KernelConfigs): - kernel_name: str = "fp8_scaled_mm_per_token" +class ScaledMMKernelConfig(KernelConfigs): + kernel_name: str = "scaled_mm_per_token" @classmethod @lru_cache(maxsize=200) @@ -105,6 +105,7 @@ def _scaled_mm_per_token( BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, + ACC_DTYPE: tl.constexpr, ): pid = tl.program_id(0) m_block_num = tl.cdiv(M, BLOCK_M) @@ -134,7 +135,7 @@ def _scaled_mm_per_token( a_s = tl.load(Ascale_ptrs) b_s = tl.load(Bscale_ptrs) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_K)): if USE_TMA: @@ -155,6 +156,7 @@ def _scaled_mm_per_token( a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk + acc = acc.to(tl.float32) acc = acc * a_s[:, None] * b_s[None, :] acc = acc.to(out.dtype.element_ty) @@ -206,13 +208,13 @@ def _get_static_key(A, B, out_dtype): @autotune( - kernel_name="fp8_scaled_mm_per_token:v3", + kernel_name="scaled_mm_per_token:v1", configs_gen_func=get_test_configs, static_key_func=_get_static_key, run_key_func=lambda A: A.shape[0], mutates_args=["out"], ) -def fp8_scaled_mm_per_token( +def scaled_mm_per_token( A: torch.Tensor, B: torch.Tensor, Ascale: torch.Tensor, @@ -221,7 +223,7 @@ def fp8_scaled_mm_per_token( out: torch.Tensor, run_config=None, ) -> torch.Tensor: - """w8a8fp8 per-token quantization mm. + """w8a8 per-token quantization mm (supports fp8 and int8). Args: A: Matrix A with shape of [M, K]. @@ -239,7 +241,7 @@ def fp8_scaled_mm_per_token( M, K = A.shape _, N = B.shape if not run_config: - run_config = Fp8ScaledMMKernelConfig.try_to_get_best_config(M=M, N=N, K=K, out_dtype=out_dtype) + run_config = ScaledMMKernelConfig.try_to_get_best_config(M=M, N=N, K=K, out_dtype=out_dtype) NEED_N_MASK = N % run_config["BLOCK_N"] != 0 NEED_K_MASK = K % run_config["BLOCK_K"] != 0 grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),) @@ -283,6 +285,8 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): B_desc = None out_desc = None + ACC_DTYPE = tl.int32 if A.dtype == torch.int8 else tl.float32 + _scaled_mm_per_token[grid]( A=A, A_desc=A_desc, @@ -305,12 +309,17 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): B_IS_TRANS=B_is_trans, NEED_N_MASK=NEED_N_MASK, NEED_K_MASK=NEED_K_MASK, + ACC_DTYPE=ACC_DTYPE, **run_config, ) return out +fp8_scaled_mm_per_token = scaled_mm_per_token +int8_scaled_mm_per_token = scaled_mm_per_token + + if __name__ == "__main__": import time import os @@ -324,7 +333,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): M_list = [1, 2, 4, 8, 16, 32, 48] print(f"{'='*80}") - print(f"Starting Autotune for FP8 Scaled MM (N={N}, K={K})") + print(f"Starting Autotune for Scaled MM (N={N}, K={K})") print(f"M values to test: {M_list}") print(f"Total configs per M: {len(get_test_configs())}") print(f"{'='*80}\n") @@ -360,7 +369,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): gt_C = d_A.mm(d_B) # 运行kernel验证正确性 - fp8_scaled_mm_per_token(A_verify, B, Ascale_verify, Bscale, output_dtype, out_verify) + scaled_mm_per_token(A_verify, B, Ascale_verify, Bscale, output_dtype, out_verify) # 计算cosine similarity cosine_sim = F.cosine_similarity(out_verify.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), 
dim=1) @@ -390,7 +399,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): A = test_data[M]["A"] Ascale = test_data[M]["Ascale"] out = test_data[M]["out"] - fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) print(f"[M={M}] Autotune completed!") Autotuner.end_autotune_warmup() @@ -418,7 +427,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): gt_C = d_A.mm(d_B) # 运行一次确保结果正确 - fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) sgl_res = fp8_scaled_mm(A, B, Ascale, Bscale, output_dtype) cosine_sim = F.cosine_similarity(out.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) @@ -437,7 +446,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): ms_sgl = triton.testing.do_bench(fn_sgl, warmup=25, rep=100) # Our kernel - fn_ours = lambda: fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + fn_ours = lambda: scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) ms_ours = triton.testing.do_bench_cudagraph(fn_ours, rep=100) print(f"[M={M}] BF16: {ms_bf16:.3f} ms") diff --git a/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py b/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py deleted file mode 100644 index 4f3f6a385c..0000000000 --- a/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py +++ /dev/null @@ -1,376 +0,0 @@ -import time -import torch - -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=2, num_warps=4), - triton.Config({}, num_stages=2, num_warps=2), - triton.Config({}, num_stages=2, num_warps=1), - ], - key=['K'], -) -@triton.jit -def quantize_int8_perrow_kernel( - fpa_ptr, a_ptr, as_ptr, - M, K, - stride_fpam, stride_fpak, - stride_am, stride_ak, - stride_asm, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - - fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - a_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - a_max = tl.maximum(a_max, tl.max(tl.abs(fpa), axis=1)) - fpa_ptrs += BLOCK_SIZE_K * stride_fpak - a_scale = (a_max / 127.) 
- fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - inta = (fpa / a_scale[:, None]).to(tl.int8) - tl.store(a_ptrs, inta, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K) - fpa_ptrs += BLOCK_SIZE_K * stride_fpak - a_ptrs += BLOCK_SIZE_K * stride_ak - as_offs = pid_m * BLOCK_SIZE_M * stride_asm + tl.arange(0, BLOCK_SIZE_M) - tl.store(as_ptr + as_offs, a_scale) - - -def quantize_int8_perrow(fpa): - a = torch.empty(fpa.shape, device=fpa.device, dtype=torch.int8) - a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=fpa.dtype) - M, K = fpa.shape - BLOCK_SIZE_M = 1 - BLOCK_SIZE_K = triton.next_power_of_2(K) - grid = (M // BLOCK_SIZE_M,) - quantize_int8_perrow_kernel[grid]( - fpa, a, a_scale, - M, K, - fpa.stride(0), fpa.stride(1), - a.stride(0), a.stride(1), - a_scale.stride(0), - BLOCK_SIZE_M, BLOCK_SIZE_K, - ) - return a, a_scale - - -@triton.autotune( - configs=[ - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, 
num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - reset_to_zero=['c_ptr'] -) -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, as_ptr, b_ptr, bs_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_asm, - stride_bk, stride_bn, - stride_bsn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. 
- pid = tl.program_id(axis=0) - pid_sp_k = tl.program_id(axis=1) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetics` section for details - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - as_ptrs = as_ptr + offs_am * stride_asm - bs_ptrs = bs_ptr + offs_bn * stride_bsn - a_scale = tl.load(as_ptrs, mask=offs_am < M, other=0.0) - b_scale = tl.load(bs_ptrs, mask=offs_bn < N, other=0.0) - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(c_ptr.dtype.element_ty) - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) - - -def matmul_quantize_int8(fpa, b, b_scale, out=None): - a, a_scale = quantize_int8_perrow(fpa) - # a, a_scale = quantize_int8(fpa, axis=1) - return matmul_int8(a, a_scale, b, b_scale, out) - - -def matmul_int8(a, a_scale, b, b_scale, out=None): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - M, K = a.shape - K, N = b.shape - # Allocates output. - if out == None: - c = torch.zeros((M, N), device=a.device, dtype=torch.float16) - else: - c = out.fill_(0.) 
- grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'], - ) - matmul_kernel[grid]( - a, a_scale, b, b_scale, c, - M, N, K, - a.stride(0), a.stride(1), - a_scale.stride(0), - b.stride(0), b.stride(1), - b_scale.stride(0), - c.stride(0), c.stride(1), - ) - return c - - -def quantize_int8(weight, axis=0, tp_rank=0): - # Weight shape: [H1, H2] - # Scale shape: [H2] - scale = weight.abs().amax(axis, keepdim=True) / 127. - weight = (weight / scale).to(torch.int8) - # col major will accelerate i8xi8 kernel. - if axis == 0: - weight = weight.t().contiguous().t() - scale = scale.squeeze(axis) - return weight.contiguous().cuda(tp_rank), scale.contiguous().cuda(tp_rank) - - -def test_correct_int8(M=32, N=4096, K=4096): - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_a, scale_a = quantize_int8_perrow(a) - cos = torch.nn.CosineSimilarity(0) - print("Quantization cos", cos((int_a * scale_a.unsqueeze(1)).flatten().to(torch.float32), a.flatten().to(torch.float32))) - int_b, scale_b = quantize_int8(b, axis=0) - triton_output = matmul_int8(int_a, scale_a, int_b, scale_b) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - cos = torch.nn.CosineSimilarity(0) - print("Output cos", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -def test_int8(M, K, N): - import time - - print("M: {} K: {} N: {}".format(M, K, N)) - torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16).contiguous() - int_b, scale_b = quantize_int8(b, axis=0) - for _ in range(10): - # int_a, a_scale = quantize_int8(a, 1) - int_a, a_scale = quantize_int8_perrow(a) - triton_output = matmul_int8(int_a, a_scale, int_b, scale_b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - #int_a, a_scale, _ = quantize_int8(a, 1) - int_a, a_scale = quantize_int8_perrow(a) - torch.cuda.synchronize() - qt2 = time.time() - for _ in range(iters): - triton_output = matmul_int8(int_a, a_scale, int_b, scale_b) - torch.cuda.synchronize() - t2 = time.time() - quant_time = qt2 - t1 - triton_time = t2 - qt2 - triton_tflops = 2 * M * N * K * 1e-12 / (triton_time / iters) - quant_bandwith = 2 * M * K * 1e-9 / (quant_time / iters) - print("Triton time cost: {} (tflops {}) + quant: {} (bandwidth {})".format( - triton_time, triton_tflops, quant_time, quant_bandwith)) - for _ in range(10): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - t2 = time.time() - torch_time = t2 - t1 - torch_tflops = 2 * M * N * K * 1e-12 / (torch_time / iters) - print("Torch time cost: {} (tflops {})".format(t2 - t1, torch_tflops)) - return triton_time, torch_time, quant_time - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M'], # Argument names to use as an x-axis for the plot - x_vals=[32, 64, 128, 256] + [ - 512 * i * 2 for i in range(1, 17) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=['cublas', 'triton-i8', 'triton-quant-i8', 'quant-perrow'], - # Label name for the lines - line_names=["cuBLAS", "Triton-i8", 
"Triton-Quant-i8", "Quant-perrow(GB/s)"], - # Line styles - styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-')], - ylabel="TFLOPS", # Label name for the y-axis - plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. - args={}, - ) -) -def benchmark(M, provider): - K = 10240 - N = 27392 * 2 // 8 - quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-i8': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - b = torch.randn((K, N), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - int_a, a_scale = quantize_int8(a, axis=1) - int_b, b_scale = quantize_int8(b, axis=0) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_int8(int_a, a_scale, int_b, b_scale), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-quant-i8': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - b = torch.randn((K, N), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - int_b, b_scale = quantize_int8(b, axis=0) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_quantize_int8(a, int_b, b_scale), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'quant-perrow': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - ms, min_ms, max_ms = triton.testing.do_bench(lambda: quantize_int8_perrow(a), quantiles=quantiles) - perf = lambda ms: 2 * M * K * 1e-9 / (ms * 1e-3) - return perf(ms), perf(min_ms), perf(max_ms) - - -def test_model_layer(bs, sqe_len, hidden, inter, tp): - st1 = 0 - st2 = 0 - st3 = 0 - t1, t2, t3 = test_int8(bs * sqe_len, hidden, hidden * 3 // tp) - st1 += t1 - st2 += t2 - st3 += t3 - t1, t2, t3 = test_int8(bs * sqe_len, hidden // tp, hidden) - st1 += t1 - st2 += t2 - st3 += t3 - t1, t2, t3 = test_int8(bs * sqe_len, hidden, inter * 2 // tp) - st1 += t1 - st2 += t2 - st3 += t3 - t1, t2, t3 = test_int8(bs * sqe_len, inter // tp, hidden) - st1 += t1 - st2 += t2 - st3 += t3 - print("Triton time {} Torch time {} Quant time {}".format(st1, st2, st3)) - - -if __name__ == "__main__": - test_correct_int8() - benchmark.run(show_plots=True, print_data=True) - - bs = 32 - hidden = 4096 - inter = 11008 - prefill_len = 512 - decode_len = 1 - tp = 1 - test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 26f59258cd..1e47454490 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -1,11 +1,10 @@ import yaml import collections from .registry import QUANTMETHODS -from .torchao_quant import * -from .w8a8_quant import * -from .triton_quant.triton_quant import * -from .deepgemm_quant import * -from .awq_quant import * +from .w8a8 import * +from .deepgemm import * +from .awq import * +from .no_quant import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -37,7 +36,7 @@ def _mapping_quant_method(self): if self.hf_quantization_method == "fp8": block_size = 
self.hf_quantization_config.get("weight_block_size", None) if block_size == [128, 128]: - from lightllm.common.quantization.deepgemm_quant import HAS_DEEPGEMM + from lightllm.common.quantization.deepgemm import HAS_DEEPGEMM if HAS_DEEPGEMM: self.quant_type = "deepgemm-fp8w8a8-b128" @@ -78,4 +77,6 @@ def get_quant_type(self, layer_num, name): def get_quant_method(self, layer_num, name): quant_type = self.get_quant_type(layer_num, name) - return QUANTMETHODS.get(quant_type) + quant_method = QUANTMETHODS.get(quant_type) + quant_method.hf_quantization_config = self.hf_quantization_config + return quant_method diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq.py similarity index 50% rename from lightllm/common/quantization/awq_quant.py rename to lightllm/common/quantization/awq.py index 8c04cdcea9..f3c7623975 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq.py @@ -1,34 +1,42 @@ -import os import torch -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -from typing import Any -from typing import TYPE_CHECKING, Optional, Tuple -from lightllm.utils.dist_utils import get_current_device_id - -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from typing import Any, Optional, Tuple, List -if HAS_VLLM: - awq_dequantize = vllm_ops.awq_dequantize - awq_gemm = vllm_ops.awq_gemm - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, - marlin_permute_scales, - awq_to_marlin_zero_points, - should_use_atomic_add_reduce, - marlin_make_empty_g_idx, - marlin_make_workspace_new, - ) - from vllm.scalar_type import scalar_types - - TYPE_MAP = { - 4: scalar_types.uint4, - 8: scalar_types.uint8, - } +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + if HAS_VLLM: + awq_dequantize = vllm_ops.awq_dequantize + awq_gemm = vllm_ops.awq_gemm + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + marlin_permute_scales, + awq_to_marlin_zero_points, + should_use_atomic_add_reduce, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + ) + from vllm.scalar_type import scalar_types + + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + else: + awq_dequantize = None + awq_gemm = None + TYPE_MAP = {} +except ImportError: + HAS_VLLM = False + awq_dequantize = None + awq_gemm = None + TYPE_MAP = {} class AWQBaseQuantizationMethod(QuantizationMethod): @@ -39,16 +47,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: 
Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("AWQ online quantization is not supported yet.") @@ -57,7 +66,7 @@ def method_name(self): return "awq-base" -@QUANTMETHODS.register("awq") +@QUANTMETHODS.register("awq", platform="cuda") class AWQW4A16QuantizationMethod(AWQBaseQuantizationMethod): def __init__(self): super().__init__() @@ -72,21 +81,21 @@ def __init__(self): def method_name(self): return "awq" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 if NEED_DEQUANT_WEIGHT: @@ -99,8 +108,34 @@ def apply( out.add_(bias) return out + def _create_weight( + self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + weight_out_dims = [_out_dim // self.pack_factor for _out_dim in out_dims] + weight_scale_out_dims = out_dims + weight_zero_point_out_dims = weight_out_dims + mm_param = WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=weight_out_dims, + weight_split_dim=-1, + weight_scale_out_dims=weight_scale_out_dims, + weight_scale_split_dim=-1, + weight_zero_point_out_dims=weight_zero_point_out_dims, + weight_zero_point_split_dim=-1, + ) + return mm_param, mm_param_list -@QUANTMETHODS.register("awq_marlin") + +@QUANTMETHODS.register("awq_marlin", platform="cuda") class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): def __init__(self): super().__init__() @@ -115,76 +150,27 @@ def __init__(self): self.vllm_quant_type = TYPE_MAP[self.nbits] self.has_weight_scale = True self.has_weight_zero_point = True + self.tile_size = 16 @property def method_name(self): return "awq_marlin" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack): raise NotImplementedError("AWQ online quantization is not supported yet.") - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return True - - def params_repack( - self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 一些量化方法在将参数完成量化后,为了加速性能,还需要将参数进行重拍,使算子性能达到最优,如awq方法。 - """ - weight = self._process_weight_after_loading(weight.cuda(get_current_device_id())) - weight_scale = self._process_weight_scale_after_loading( - 
weight_scale.cuda(get_current_device_id()).to(dtype_type) - ) - weight_zero_point = self._process_weight_zero_point_after_loading( - weight_zero_point.cuda(get_current_device_id()) - ) - return weight, weight_scale, weight_zero_point - - def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - self.k = weight.shape[0] - self.n = weight.shape[1] * self.pack_factor - return vllm_ops.awq_marlin_repack( - weight, - size_k=weight.shape[0], - size_n=weight.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - - def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - group_size = self.hf_quantization_config["group_size"] - return marlin_permute_scales( - weight_scale, - size_k=weight_scale.shape[0] * group_size, - size_n=weight_scale.shape[1], - group_size=self.hf_quantization_config["group_size"], - ) - - def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return awq_to_marlin_zero_points( - weight_zero_point, - size_k=weight_zero_point.shape[0], - size_n=weight_zero_point.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) use_atomic_add = should_use_atomic_add_reduce( @@ -219,6 +205,81 @@ def apply( out.add_(bias) return out + def _create_weight( + self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) + self.n = out_dim + self.k = in_dim + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty( + expert_prefix + (in_dim // self.tile_size, out_dim * self.tile_size // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + weight_out_dims = [_out_dim * self.tile_size // self.pack_factor for _out_dim in out_dims] + weight_scale_out_dims = out_dims + weight_zero_point_out_dims = [_out_dim // self.pack_factor for _out_dim in out_dims] + mm_param = WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=weight_out_dims, + weight_split_dim=-1, + weight_scale_out_dims=weight_scale_out_dims, + weight_scale_split_dim=-1, + weight_zero_point_out_dims=weight_zero_point_out_dims, + weight_zero_point_split_dim=-1, + ) + return mm_param, mm_param_list + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + if weight is None: + return + device_id 
= get_current_device_id() + repack_weight = vllm_ops.awq_marlin_repack( + weight.cuda(device_id), + size_k=weight.shape[0], + size_n=weight.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + weight_pack.weight.copy_(repack_weight) + weight_pack.load_ok[0] = True + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + if weight_scale is None: + return + group_size = self.hf_quantization_config["group_size"] + device_id = get_current_device_id() + repack_weight_scale = marlin_permute_scales( + weight_scale.cuda(device_id), + size_k=weight_scale.shape[0] * group_size, + size_n=weight_scale.shape[1], + group_size=self.hf_quantization_config["group_size"], + ) + weight_pack.weight_scale.copy_(repack_weight_scale) + weight_pack.load_ok[1] = True + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack) -> None: + if weight_zero_point is None: + return + device_id = get_current_device_id() + repack_weight_zero_point = awq_to_marlin_zero_points( + weight_zero_point.cuda(device_id), + size_k=weight_zero_point.shape[0], + size_n=weight_zero_point.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + weight_pack.weight_zero_point.copy_(repack_weight_zero_point) + weight_pack.load_ok[2] = True + return + # adapted from # https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py new file mode 100644 index 0000000000..137455a821 --- /dev/null +++ b/lightllm/common/quantization/deepgemm.py @@ -0,0 +1,135 @@ +import torch +from typing import Optional, List, Union, Tuple + +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + import deep_gemm + + HAS_DEEPGEMM = True +except ImportError: + HAS_DEEPGEMM = False + + +class DeepGEMMBaseQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + raise NotImplementedError("Not implemented") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") + + @property + def method_name(self): + return "deepgemm-base" + + +@QUANTMETHODS.register(["deepgemm-fp8w8a8-b128"], platform="cuda") +class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 128 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = "weight_scale_inv" + self.has_weight_scale = True + 
self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp8w8a8-b128" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_quant_kernel import weight_quant + + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight.copy_(weight) + output.weight_scale.copy_(scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + input_scale = None + alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty + m, k = input_tensor.shape + n = qweight.shape[0] + if input_scale is None: + qinput_tensor, input_scale = per_token_group_quant_fp8( + input_tensor, + self.block_size, + dtype=qweight.dtype, + column_major_scales=True, + scale_tma_aligned=True, + alloc_func=alloc_func, + ) + + if out is None: + out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) + return out + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + weight_scale_out_dims = [(_out_dim + self.block_size - 1) // self.block_size for _out_dim in out_dims] + divisible_by_block_size = [_out_dim % self.block_size != 0 for _out_dim in out_dims] + if sum(divisible_by_block_size) > 1: + raise ValueError( + f"out_dims only contains one dim can not be divisible \ + by block_size {self.block_size}, but got {out_dims}" + ) + weight_scale_out_dim = sum(weight_scale_out_dims) + weight_scale_in_dim = (in_dim + self.block_size - 1) // self.block_size + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (weight_scale_out_dim, weight_scale_in_dim), dtype=torch.float32 + ).cuda(device_id) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=weight_scale_out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + +def _deepgemm_fp8_nt(a_tuple, b_tuple, out): + if HAS_DEEPGEMM: + if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): + return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out) + if hasattr(deep_gemm, "fp8_gemm_nt"): + return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out) + raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version") diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py deleted file mode 100644 index 7dbd3806b9..0000000000 --- a/lightllm/common/quantization/deepgemm_quant.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import torch -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from 
lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( - per_token_group_quant_fp8, - tma_align_input_scale, -) -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack -try: - HAS_DEEPGEMM = True - import deep_gemm -except: - HAS_DEEPGEMM = False - - -class DeepGEMMBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = g_cache_manager - assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" - - def quantize(self, weight: torch.Tensor): - """ """ - pass - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - ) -> torch.Tensor: - raise NotImplementedError("Not implemented") - - @property - def method_name(self): - return "deepgemm-base" - - -@QUANTMETHODS.register(["deepgemm-fp8w8a8-b128"]) -class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.block_size = 128 - self.weight_suffix = None - self.weight_zero_point_suffix = None - self.weight_scale_suffix = "weight_scale_inv" - self.has_weight_scale = True - self.has_weight_zero_point = False - - @property - def method_name(self): - return "deepgemm-fp8w8a8-b128" - - def quantize(self, weight: torch.Tensor): - from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - - weight, scale = weight_quant(weight, self.block_size) - return weight, scale, None - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - input_scale = None - alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty - m, k = input_tensor.shape - n = qweight.shape[1] - if input_scale is None: - qinput_tensor, input_scale = per_token_group_quant_fp8( - input_tensor, - self.block_size, - dtype=qweight.dtype, - column_major_scales=True, - scale_tma_aligned=True, - alloc_func=alloc_func, - ) - - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight.t(), weight_scale.t()), out) - return out - - -def _deepgemm_fp8_nt(a_tuple, b_tuple, out): - if HAS_DEEPGEMM: - if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): - return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out) - if hasattr(deep_gemm, "fp8_gemm_nt"): - return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out) - raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version") diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py new file mode 100644 index 0000000000..3bf023f8a1 --- /dev/null +++ b/lightllm/common/quantization/no_quant.py @@ -0,0 +1,60 @@ +import torch +from typing import Optional, List, Union, Tuple + +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS + + 
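+# "none" is registered for both the cuda and musa platforms, so QUANTMETHODS.get("none", platform=...) resolves on either backend.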
+@QUANTMETHODS.register("none", platform="musa") +@QUANTMETHODS.register("none", platform="cuda") +class NoQuantization(QuantizationMethod): + """No quantization - uses full precision weights.""" + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + weight = weight_pack.weight.t() + if out is None: + shape = (input_tensor.shape[0], weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device) + else: + out = torch.empty(shape, dtype=dtype, device=device) + if bias is None: + return torch.mm(input_tensor, weight, out=out) + return torch.addmm(bias, input_tensor, weight, out=out) + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype).cuda(device_id) + mm_param = WeightPack(weight=weight, weight_scale=None, weight_zero_point=None) + # weight layout is (out_dim, in_dim), so the split dimension is -2. + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + ) + return mm_param, mm_param_list + + def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: + return False + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + return + + @property + def method_name(self): + return "none" diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 9b629bcaf1..95d8d806f9 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,38 +1,60 @@ import torch from abc import ABC, abstractmethod +from dataclasses import dataclass from lightllm.utils.dist_utils import get_current_device_id -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, List, Tuple -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +@dataclass +class WeightPack: + weight: Optional[torch.Tensor] = None + weight_scale: Optional[torch.Tensor] = None + weight_zero_point: Optional[torch.Tensor] = None + + def __post_init__(self): + self.load_ok = [False, self.weight_scale is None, self.weight_zero_point is None] + + def get_expert(self, expert_idx: int): + assert self.weight.ndim == 3, f"weight must be a 3D tensor, but got {self.weight.ndim}" + weight = self.weight[expert_idx] + weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None + weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) class QuantizationMethod(ABC): def __init__(self): super().__init__() self.device_id_ = get_current_device_id() - self.weight_suffix = None + self.weight_suffix = "weight" self.weight_scale_suffix = None self.weight_zero_point_suffix = None 
self.act_scale_suffix = None self.has_weight_scale: bool = None self.has_weight_zero_point: bool = None + self.group_size: int = -1 # -1 means no grouping, i.e. per-channel quantization; any other value is the group size + self.pack_factor: int = 1 + # extra quantization parameters required by some schemes, e.g. AWQ self.hf_quantization_config = None @abstractmethod - def quantize(self, weights: torch.Tensor): + def quantize( + self, + weight: torch.Tensor, + output: WeightPack, + ) -> None: pass @abstractmethod def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - bias: Optional[torch.Tensor] = None, + weight_pack: "WeightPack", out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass @@ -41,20 +63,87 @@ def apply( def method_name(self): pass - def weight_need_quanted(self, weight: torch.Tensor) -> bool: + def create_weight( + self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int + ) -> Tuple[WeightPack, List[WeightPack]]: + return self._create_weight( + out_dims=out_dims, + in_dim=in_dim, + dtype=dtype, + device_id=device_id, + ) + + def create_moe_weight( + self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int + ) -> Tuple[WeightPack, List[WeightPack]]: + return self._create_weight( + out_dims=out_dims, + in_dim=in_dim, + dtype=dtype, + device_id=device_id, + num_experts=num_experts, + ) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: + if self._check_weight_need_quanted(weight): + self.quantize(weight, weight_pack) + weight_pack.load_ok = [True, True, True] + return + weight_pack.weight[:].copy_(weight) + weight_pack.load_ok[0] = True + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack) -> None: + if weight_scale is None: + return + weight_pack.weight_scale.copy_(weight_scale) + weight_pack.load_ok[1] = True + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack) -> None: + if weight_zero_point is None: + return + weight_pack.weight_zero_point.copy_(weight_zero_point) + weight_pack.load_ok[2] = True + return + + def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: # 判断一个 weight 是否需要进行量化操作。 return weight.dtype in [torch.bfloat16, torch.float16, torch.float32, torch.float64] - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return False - - def params_repack( - self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 一些量化方法在将参数完成量化后,为了加速性能,还需要将参数进行重拍,使算子性能达到最优,如awq方法。 - """ - return weight, weight_scale, weight_zero_point + def _create_weight( + self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + pass + + def _split_weight_pack( + self, + weight_pack: WeightPack, + weight_out_dims: List[int], + weight_split_dim: Optional[int], + weight_scale_out_dims: List[int] = None, + weight_scale_split_dim: Optional[int] = None, + weight_zero_point_out_dims: List[int] = None, + weight_zero_point_split_dim: Optional[int] = None, + ) -> List[WeightPack]: + # only support per-channel or block-wise quantization for now.
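+ # the fused weight is split along its output dim; weight_scale and weight_zero_point are split along their own out dims, so each returned WeightPack views the slice that belongs to one fused sub-weight.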
+ mm_param_list: List[WeightPack] = [] + weight = torch.split(weight_pack.weight, weight_out_dims, dim=weight_split_dim) + weight_scale = ( + [None] * len(weight_out_dims) + if weight_pack.weight_scale is None + else (torch.split(weight_pack.weight_scale, weight_scale_out_dims, dim=weight_scale_split_dim)) + ) + # the ndim of weight_zero_point is the same as weight_scale. + weight_zero_point = ( + [None] * len(weight_out_dims) + if weight_pack.weight_zero_point is None + else ( + torch.split(weight_pack.weight_zero_point, weight_zero_point_out_dims, dim=weight_zero_point_split_dim) + ) + ) + for weight, weight_scale, weight_zero_point in zip(weight, weight_scale, weight_zero_point): + mm_param_list.append( + WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + ) + return mm_param_list diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index 674a22b60f..c9baa64e27 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -1,28 +1,31 @@ from .quantize_method import QuantizationMethod -from typing import Type class QuantMethodFactory: def __init__(self): self._quant_methods = {} - def register(self, names): + def register(self, names, platform="cuda"): def decorator(cls): local_names = names if isinstance(local_names, str): local_names = [local_names] for n in local_names: - self._quant_methods[n] = cls + if n not in self._quant_methods: + self._quant_methods[n] = {} + self._quant_methods[n][platform] = cls return cls return decorator - def get(self, key, *args, **kwargs) -> Type[QuantizationMethod]: - if key == "none": - return None - quant_method_class = self._quant_methods.get(key) - if not quant_method_class: + def get(self, key, platform="cuda", *args, **kwargs) -> "QuantizationMethod": + quant_method_class_dict = self._quant_methods.get(key) + if not quant_method_class_dict: raise ValueError(f"QuantMethod '{key}' not supported.") + + quant_method_class = quant_method_class_dict.get(platform) + if quant_method_class is None: + raise ValueError(f"QuantMethod '{key}' for platform '{platform}' not supported.") return quant_method_class() diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py deleted file mode 100644 index ba4115b1d9..0000000000 --- a/lightllm/common/quantization/torchao_quant.py +++ /dev/null @@ -1,169 +0,0 @@ -import os -import torch -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack - -try: - HAS_TORCH_AO = True - from torchao.quantization import ( - int4_weight_only, - int8_weight_only, - float8_weight_only, - fpx_weight_only, - int8_dynamic_activation_int8_weight, - float8_dynamic_activation_float8_weight, - quantize_, - ) - from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - ) -except: - HAS_TORCH_AO = False - - -class AOBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - assert HAS_TORCH_AO, "torchao is not installed, you can't use quant api of it" - assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" - self.quant_func = None - - def quantize(self, weight: torch.Tensor): - """ """ - dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - 
dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) - quantize_(dummy_linear, self.quant_func) - return dummy_linear.weight, None, None - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - ) -> torch.Tensor: - weights = weight_pack.weight - bias = weight_pack.bias - return F.linear(input_tensor, weights, bias) - - @property - def method_name(self): - return "ao-base" - - -@QUANTMETHODS.register(["ao-w4a16-256"]) -class AOW4A16QuantizationMethodGroup256(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 256 - self.quant_func = int4_weight_only(group_size=self.group_size) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w4a16-256" - - -@QUANTMETHODS.register(["ao-w4a16-128"]) -class AOW4A16QuantizationMethodGroup128(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 128 - self.quant_func = int4_weight_only(group_size=self.group_size) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w4a16-128" - - -@QUANTMETHODS.register(["ao-w4a16-64"]) -class AOW4A16QuantizationMethodGroup64(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 64 - self.quant_func = int4_weight_only(group_size=self.group_size) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w4a16-64" - - -@QUANTMETHODS.register(["ao-w4a16-32"]) -class AOW4A16QuantizationMethodGroup32(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 32 - self.quant_func = int4_weight_only(group_size=self.group_size) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w4a16-32" - - -@QUANTMETHODS.register("ao-w8a8") -class AOW8A8QuantizationMethod(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.quant_func = int8_dynamic_activation_int8_weight() - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w8a8" - - -@QUANTMETHODS.register("ao-w8a16") -class AOW8A16QuantizationMethod(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.quant_func = int8_weight_only() - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w8a16" - - -@QUANTMETHODS.register("ao-fp8w8a16") -class AOFP8W8A16QuantizationMethod(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - assert is_cuda_8_9, "FP8 requires GPU with compute capability >= 8.9" - self.quant_func = float8_weight_only() - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-fp8w8a16" - - -@QUANTMETHODS.register("ao-fp6w6a16") -class AOFP6W6A16QuantizationMethod(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - assert TORCH_VERSION_AT_LEAST_2_5, "torchao fp6 requires torch >=2.5" - self.quant_func = fpx_weight_only(3, 2) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return 
"ao-fp6w6a16" diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py deleted file mode 100644 index 410f925a5e..0000000000 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import torch -import torch.nn.functional as F -from lightllm.common.quantization.quantize_method import QuantizationMethod -from lightllm.common.quantization.registry import QUANTMETHODS -from .fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul -from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack - - -class TritonBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = g_cache_manager - - def quantize(self, weight: torch.Tensor): - pass - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - ) -> torch.Tensor: - raise NotImplementedError("Not implemented") - - -@QUANTMETHODS.register(["triton-fp8w8a8-block128"]) -class TritonFP8w8a8QuantizationMethod(TritonBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.is_moe = False - self.block_size = 128 - self.weight_suffix = None - self.weight_zero_point_suffix = None - self.weight_scale_suffix = "weight_scale_inv" - self.has_weight_scale = True - self.has_weight_zero_point = False - - def quantize(self, weight: torch.Tensor): - # TODO block-wise quant kernel - pass - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - input_scale = None - m, k = input_tensor.shape - n = qweight.shape[1] - alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty - if input_scale is None: - input_tensor_q, input_scale = per_token_group_quant_fp8( - input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func - ) - else: - # TODO - raise "statci input scale is not supported by triton fp8 block gemm kernel." 
- m = input_tensor.shape[0] - n = qweight.shape[1] - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - w8a8_block_fp8_matmul( - input_tensor_q, - qweight, - input_scale, - weight_scale, - out, - (self.block_size, self.block_size), - dtype=input_tensor.dtype, - ) - return out diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8.py similarity index 51% rename from lightllm/common/quantization/w8a8_quant.py rename to lightllm/common/quantization/w8a8.py index 31004de4e3..98626e1d36 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8.py @@ -1,18 +1,17 @@ import os import torch - -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +import torch.nn.functional as F +from typing import Optional, List, Union, Tuple from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS -import torch.nn.functional as F -from typing import Optional, TYPE_CHECKING -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul +from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +from .quantize_method import WeightPack if HAS_LIGHTLLM_KERNEL: @@ -38,16 +37,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): - pass + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -55,35 +55,40 @@ def apply( def method_name(self): return "w8a8-base" + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + raise NotImplementedError("Not implemented") + -@QUANTMETHODS.register(["vllm-w8a8", "w8a8"]) +@QUANTMETHODS.register(["vllm-w8a8", "w8a8"], platform="cuda") class w8a8QuantizationMethod(BaseQuantizationMethod): def __init__(self): super().__init__() self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): - if isinstance(weight, tuple): - return (weight[0].transpose(0, 1).cuda(self.device_id_),) + weight[1:] - weight = weight.float() + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + weight = weight.float().cuda(self.device_id_) scale = weight.abs().max(dim=-1)[0] / 127 - weight = weight.transpose(0, 1) / scale.reshape(1, -1) - weight = torch.round(weight.clamp(min=-128, 
max=127)).to(dtype=torch.int8) - return weight.cuda(self.device_id_), scale.cuda(self.device_id_), None + weight = weight / scale.reshape(-1, 1) + weight = torch.round(weight.clamp(min=-127, max=127)).to(dtype=torch.int8) + output.weight.copy_(weight) + output.weight_scale.copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_scale = None - qweight = weight_pack.weight + qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - bias = weight_pack.bias input_scale = None # dynamic quantization for input tensor x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) m = input_tensor.shape[0] @@ -100,48 +105,51 @@ def apply( def method_name(self): return "vllm-w8a8" + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-1, + ) + return mm_param, mm_param_list + -@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"]) +@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"], platform="cuda") class FP8w8a8QuantizationMethod(BaseQuantizationMethod): def __init__(self): super().__init__() - self.is_moe = False self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): - if self.is_moe: - return self.quantize_moe(weight) + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + qweight, weight_scale = scaled_fp8_quant( - weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) - return qweight.transpose(0, 1), weight_scale, None - - def quantize_moe(self, weight: torch.Tensor): - num_experts = weight.shape[0] - qweights = [] - weight_scales = [] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) - for i in range(num_experts): - qweight, weight_scale = scaled_fp8_quant( - weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - qweights[i] = qweight - weight_scales.append(weight_scale) - weight_scale = torch.stack(weight_scales, dim=0).contiguous() - return qweights, weight_scale, None + output.weight.copy_(qweight) + output.weight_scale.copy_(weight_scale.view(-1)) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight + qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - bias = weight_pack.bias x_q, x_scale = 
scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] n = qweight.shape[1] @@ -160,8 +168,26 @@ def apply( def method_name(self): return "vllm-fp8w8a8" + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-1, + ) + return mm_param, mm_param_list + -@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"]) +@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"], platform="cuda") class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): def __init__(self): super().__init__() @@ -170,21 +196,26 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - raise Exception("Not implemented") + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight.copy_(weight) + output.weight_scale.copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - bias = weight_pack.bias + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale.t() input_scale = None # dynamic quantization for input tensor m, k = input_tensor.shape n = qweight.shape[1] @@ -213,3 +244,23 @@ def apply( @property def method_name(self): return "vllm-fp8w8a8-b128" + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + weight_scale_out_dims = [_out_dim // self.block_size for _out_dim in out_dims] + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=weight_scale_out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index d5c96f821a..d606d757c1 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -136,9 +136,9 @@ def get_group(self, group_index: int) -> CustomProcessGroup: return self.groups[group_index] def new_deepep_group(self, n_routed_experts, 
hidden_size): - moe_mode = os.getenv("MOE_MODE", "TP") + enable_ep_moe = get_env_start_args().enable_ep_moe num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() - if moe_mode == "TP": + if not enable_ep_moe: self.ep_buffer = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index f4fff116cd..ec1d94458f 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -16,6 +16,4 @@ def __init__(self, network_config): return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.final_norm_weight_.layernorm_forward( - input=input, eps=self.eps_, alloc_func=self.alloc_tensor - ) + return layer_weight.final_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index dfe396ab52..e84069e116 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -15,17 +15,17 @@ def __init__(self, network_config): return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.pre_norm_weight_.layernorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + return layer_weight.pre_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 808788f71a..60d584eebd 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -57,14 +57,14 @@ def _token_attention_kernel( def _att_norm( self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( + return layer_weight.att_norm_weight_( input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn_norm( self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( + return 
layer_weight.ffn_norm_weight_( input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index 83f7674531..2a9c86b8f3 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -1,25 +1,35 @@ -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LayerNormWeight, LMHeadWeight class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - self.pre_norm_weight_ = NoTpNormWeight( + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.pre_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="word_embeddings_layernorm.weight", data_type=self.data_type_, bias_name="word_embeddings_layernorm.bias", ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="ln_f.weight", data_type=self.data_type_, bias_name="ln_f.bias", ) self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="word_embeddings.weight", data_type=self.data_type_, ) - self.lm_head_weight_ = self.wte_weight_ + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="word_embeddings.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_, + ) diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 599893655d..568a9c0381 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -108,18 +108,18 @@ def load_hf_weights(self, weights): def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._gate_up_weight_name, data_type=self.data_type_, bias_names=self._gate_up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 8695f2de89..e1e435cce4 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -63,8 +63,8 @@ def _bind_func(self): def _bind_ffn(self): if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": + enable_ep_moe = get_env_start_args().enable_ep_moe + if enable_ep_moe: self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) self._tpsp_ffn = self._tpsp_ffn_ep else: @@ -165,20 +165,14 @@ def _get_qkv( q, cache_kv = 
layer_weight.qkv_a_proj_with_mqa_.mm(input).split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) - q = layer_weight.q_a_layernorm_.rmsnorm_forward( - input=q, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) + q = layer_weight.q_a_layernorm_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - layer_weight.kv_a_layernorm_.rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - eps=self.eps_, - out=cache_kv[:, :, : self.kv_lora_rank], + layer_weight.kv_a_layernorm_( + cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) rotary_emb_fwd( @@ -208,10 +202,8 @@ def _tpsp_get_qkv( cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - layer_weight.kv_a_layernorm_.rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - eps=self.eps_, - out=cache_kv[:, :, : self.kv_lora_rank], + layer_weight.kv_a_layernorm_( + cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) rotary_emb_fwd( q_rope, @@ -244,19 +236,13 @@ def _tpsp_get_qkv( position_sin = infer_state.position_sin q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1) - q = layer_weight.q_a_layernorm_.rmsnorm_forward( - q, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) + q = layer_weight.q_a_layernorm_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - layer_weight.kv_a_layernorm_.rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - eps=self.eps_, - out=cache_kv[:, :, : self.kv_lora_rank], + layer_weight.kv_a_layernorm_( + cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) rotary_emb_fwd( q_rope, diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c5a2d33527..783e70e64b 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -6,13 +6,11 @@ from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, - COLMMWeight, - NoTpNormWeight, - FusedMoeWeightEP, ROWBMMWeight, - create_tp_moe_wegiht_obj, + COLMMWeight, + RMSNormWeight, + FusedMoeWeight, ) -from functools import partial from ..triton_kernel.weight_dequant import weight_dequant @@ -40,10 +38,16 @@ def _parse_config(self): self.kv_lora_rank = self.network_config_["kv_lora_rank"] self.num_fused_shared_experts = 0 if get_env_start_args().enable_fused_shared_experts and self.is_moe: - # MOE_MODE 处于 TP 模式下才能使能 enable_fused_shared_experts - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode == "TP" + # 
enable_fused_shared_experts can only work with tensor parallelism + assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + self.moe_inter = self.network_config_.get("moe_intermediate_size", self.n_inter) + self.q_out_dim = self.num_attention_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim) + self.kv_a_out_dim = self.kv_lora_rank + self.qk_rope_head_dim + self.kv_b_out_dim = self.num_attention_heads * (self.qk_nope_head_dim + self.v_head_dim) + self.o_in_dim = self.num_attention_heads * self.v_head_dim def _init_weight_names(self): if self.q_lora_rank is None: @@ -60,31 +64,14 @@ def _init_weight(self): self._init_ffn() self._init_norm() - def _load_kb(self, kv_b_proj_): - k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[ - :, : self.qk_nope_head_dim, : - ] - return k_b_proj_.contiguous().to(kv_b_proj_.dtype) - - def _load_kb_scale(self, kv_b_proj_, block_size): - k_b_proj_scale_ = kv_b_proj_.view( - self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size - )[:, : self.qk_nope_head_dim // block_size, :] - return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype) - - def _load_vb(self, kv_b_proj_): - v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[ - :, :, self.qk_nope_head_dim : - ].transpose(0, 1) - return v_b_proj_.contiguous().to(kv_b_proj_.dtype) - - def _load_vb_scale(self, kv_b_proj_scale_, block_size): - v_b_proj_scale_ = kv_b_proj_scale_.T.view( - self.kv_lora_rank // block_size, - self.num_attention_heads, - self.qk_nope_head_dim * 2 // block_size, - )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) - return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + def _split_kv_b_proj(self, kv_b_proj_): + kv_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank) + k_b_proj_, v_b_proj_ = torch.split(kv_b_proj_, [self.qk_nope_head_dim, self.v_head_dim], dim=-2) + # num_attention_heads x qk_nope_head_dim x kv_lora_rank + k_b_proj_ = k_b_proj_.contiguous().to(kv_b_proj_.dtype) + # num_attention_heads x kv_lora_rank x v_head_dim + v_b_proj_ = v_b_proj_.transpose(1, 2).contiguous().to(kv_b_proj_.dtype) + return k_b_proj_, v_b_proj_ def _rename_shared_experts(self, weights, weight_scale_suffix): # 将共享专家对应的参数,改造为与路由专家一致的权重名称和映射关系。 @@ -108,7 +95,6 @@ def load_hf_weights(self, weights): weight_scale_suffix = None if self.quant_cfg.quantized_weight: weight_scale_suffix = kv_b_quant_method.weight_scale_suffix - if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] # for deepseek_v3, the bmm operator is not quantized @@ -117,21 +103,9 @@ def load_hf_weights(self, weights): kv_b_proj_.cuda(), weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), ).cpu() - weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_) - weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_) - - if ( - self.quant_cfg.quantized_weight - and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." 
+ weight_scale_suffix in weights - ): - kv_b_proj_scale_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix] - block_size = 128 - weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + weight_scale_suffix] = self._load_kb_scale( - kv_b_proj_scale_, block_size - ) - weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + weight_scale_suffix] = self._load_vb_scale( - kv_b_proj_scale_, block_size - ) + k_b_proj_, v_b_proj_ = self._split_kv_b_proj(kv_b_proj_) + weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = k_b_proj_ + weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = v_b_proj_ # rename the shared experts weight if self.num_fused_shared_experts > 0: @@ -141,116 +115,120 @@ def load_hf_weights(self, weights): def _init_qkvo(self): if self.q_lora_rank is None: self.q_weight_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.q_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.q_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_weight", + quant_method=self.get_quant_method("q_weight"), ) self.kv_a_proj_with_mqa_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.kv_a_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_a_proj_with_mqa", + quant_method=self.get_quant_method("kv_a_proj_with_mqa"), tp_rank=0, tp_world_size=1, ) else: self.qkv_a_proj_with_mqa_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.q_lora_rank, self.kv_a_out_dim], weight_names=[ f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", ], data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="qkv_a_proj_with_mqa", + quant_method=self.get_quant_method("qkv_a_proj_with_mqa"), tp_rank=0, tp_world_size=1, ) self.q_b_proj_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.q_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_b_proj", + quant_method=self.get_quant_method("q_b_proj"), ) self.k_b_proj_ = ROWBMMWeight( + dim0=self.num_attention_heads, + dim1=self.qk_nope_head_dim, + dim2=self.kv_lora_rank, weight_names=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", data_type=self.data_type_, - quant_cfg=None, - layer_num=self.layer_num_, - name="k_b_proj", + quant_method=None, ) self.v_b_proj_ = ROWBMMWeight( + dim0=self.num_attention_heads, + dim1=self.kv_lora_rank, + dim2=self.v_head_dim, weight_names=f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", data_type=self.data_type_, - quant_cfg=None, - layer_num=self.layer_num_, - name="v_b_proj", + quant_method=None, ) if self.enable_cc_method: self.cc_kv_b_proj_ = ROWMMWeight( + in_dim=self.kv_lora_rank, + out_dims=[self.kv_b_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="cc_kv_b_proj", + quant_method=self.get_quant_method("cc_kv_b_proj"), ) self.o_weight_ = COLMMWeight( + in_dim=self.o_in_dim, + out_dims=[self.n_embed], weight_names=f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - 
name="o_weight", + quant_method=self.get_quant_method("o_weight"), ) - def _load_mlp(self, mlp_prefix): - moe_mode = os.getenv("MOE_MODE", "TP") - if self.is_moe and moe_mode == "EP": + def _load_mlp(self, mlp_prefix, is_shared_experts=False): + enable_ep_moe = get_env_start_args().enable_ep_moe + mlp_inter = self.moe_inter if is_shared_experts else self.n_inter + if self.is_moe and enable_ep_moe: self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[mlp_inter, mlp_inter], weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), tp_rank=0, tp_world_size=1, ) self.down_proj = COLMMWeight( + in_dim=mlp_inter, + out_dims=[self.n_embed], weight_names=f"{mlp_prefix}.down_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), tp_rank=0, tp_world_size=1, ) else: self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[mlp_inter, mlp_inter], weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=mlp_inter, + out_dims=[self.n_embed], weight_names=f"{mlp_prefix}.down_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=None, tp_rank=0, tp_world_size=1, ) @@ -261,54 +239,47 @@ def _init_moe(self): # 专家对应的 gate_up_proj 等weight 参数。当 num_fused_shared_experts # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: - self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts") - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode in ["EP", "TP"] - if moe_mode == "TP": - self.experts = create_tp_moe_wegiht_obj( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name=self.e_score_correction_bias_name, - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - num_fused_shared_experts=self.num_fused_shared_experts, - split_inter_size=moe_intermediate_size // self.tp_world_size_, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - ) - elif moe_mode == "EP": - self.experts = FusedMoeWeightEP( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name=self.e_score_correction_bias_name, - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - ) - else: - raise ValueError(f"Unsupported moe mode: {moe_mode}") + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", 
is_shared_experts=True) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name=self.e_score_correction_bias_name, + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.n_embed, + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_ + hidden_size = self.network_config_["hidden_size"] + + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{self.layer_num_}.input_layernorm.weight", + data_type=self.data_type_, ) - self.ffn_norm_weight_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", + data_type=self.data_type_, ) - self.kv_a_layernorm_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ + self.kv_a_layernorm_ = RMSNormWeight( + dim=self.kv_lora_rank, + weight_name=f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", + data_type=self.data_type_, ) if self.q_lora_rank is not None: - self.q_a_layernorm_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ + self.q_a_layernorm_ = RMSNormWeight( + dim=self.q_lora_rank, + weight_name=f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", + data_type=self.data_type_, ) diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index adb749c40e..7e12245587 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -22,16 +22,8 @@ def _mtp_context_forward( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) @@ -43,16 +35,8 @@ def _mtp_token_forward( tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, 
tgt_embdings), dim=-1) ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 1f0815c3db..91c0b2b3f7 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -2,36 +2,40 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, LMHeadWeight, - NoTpNormWeight, + RMSNormWeight, ROWMMWeight, ) +from lightllm.common.quantization import Quantcfg class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): super().__init__(data_type, network_config) - + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], weight_names="model.layers.0.eh_proj.weight", data_type=self.data_type_, - name="eh_proj", + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), tp_rank=0, tp_world_size=1, ) - self.enorm_weight_ = NoTpNormWeight( + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.enorm.weight", data_type=self.data_type_, - bias_name=None, ) - self.hnorm_weight_ = NoTpNormWeight( + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.hnorm.weight", data_type=self.data_type_, - bias_name=None, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.shared_head.norm.weight", data_type=self.data_type_, - bias_name=None, ) # 与DeepseekV3模型共享, 不通过 load 加载 diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 0204e292ae..d9ffdb0e31 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -35,7 +35,18 @@ def _init_mem_manager(self): def _init_weights(self, start_layer_index=None): assert start_layer_index is None - super()._init_weights(start_layer_index=0) + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(0, self.config["n_layer"]) + ] self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ return diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 1f386625bf..183c4d8d45 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -37,14 +37,10 @@ def _get_qkv( q = q.view(-1, self.tp_q_head_num_, self.head_dim_) k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = layer_weight.q_norm_weight_.rmsnorm_forward( - input=q.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(cache_kv.dtype) + q = layer_weight.q_norm_weight_(input=q.float(), eps=self.eps_, alloc_func=self.alloc_tensor).to(cache_kv.dtype) - cache_kv[:, 0 : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( - input=k.float(), - eps=self.eps_, 
- alloc_func=self.alloc_tensor, + cache_kv[:, 0 : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_( + input=k.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(cache_kv.dtype) is_sliding = bool((self.layer_num_ + 1) % self.sliding_window_pattern) @@ -92,7 +88,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input1 = layer_weight.pre_feedforward_layernorm_weight_( input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) @@ -101,10 +97,8 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( - input=ffn_out.float(), - eps=self.eps_, - alloc_func=self.alloc_tensor, + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) @@ -127,7 +121,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input1 = layer_weight.pre_feedforward_layernorm_weight_( input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) @@ -136,10 +130,8 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( - input=ffn_out.float(), - eps=self.eps_, - alloc_func=self.alloc_tensor, + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 858937d8c1..7ae0fbcca3 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -5,14 +5,19 @@ class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="language_model.model.embed_tokens.weight", data_type=self.data_type_, ) self.lm_head_weight_ = self.wte_weight_ self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, weight_name="language_model.model.norm.weight", data_type=self.data_type_, bias_name=None, diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index e7808c412c..a4340a17a7 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -25,51 +25,63 @@ def _init_weight_names(self): def _init_ffn(self): self.gate_proj = ROWMMWeight( + 
in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._gate_weight_name, data_type=self.data_type_, bias_names=self._gate_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_proj", + quant_method=self.get_quant_method("gate_proj"), ) self.up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._up_weight_name, data_type=self.data_type_, bias_names=self._up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="up_proj", + quant_method=self.get_quant_method("up_proj"), ) super()._init_ffn() def _init_qkv(self): + kv_out_dim = self.k_head_num_ * self.head_dim self.k_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[kv_out_dim], weight_names=self._k_weight_name, data_type=self.data_type_, bias_names=self._k_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="k_proj", + quant_method=self.get_quant_method("k_proj"), ) self.v_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[kv_out_dim], weight_names=self._v_weight_name, data_type=self.data_type_, bias_names=self._v_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="v_proj", + quant_method=self.get_quant_method("v_proj"), ) super()._init_qkv() def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NoTpGEMMANormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NoTpGEMMANormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + + self.k_norm_weight_ = NoTpGEMMANormWeight( + dim=self.head_dim_, weight_name=self._k_norm_weight_name, data_type=self.data_type_, bias_name=None + ) + self.q_norm_weight_ = NoTpGEMMANormWeight( + dim=self.head_dim_, weight_name=self._q_norm_weight_name, data_type=self.data_type_, bias_name=None + ) self.pre_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( - self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None + dim=self.n_embed, + weight_name=self._pre_feedforward_layernorm_name, + data_type=self.data_type_, + bias_name=None, ) self.post_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( - self._post_feedforward_layernorm_name, self.data_type_, bias_name=None + dim=self.n_embed, + weight_name=self._post_feedforward_layernorm_name, + data_type=self.data_type_, + bias_name=None, ) def load_hf_weights(self, weights): diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index 468d471d2c..a25c5af4c9 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -13,8 +13,6 @@ class Gemma_2bPreLayerInfer(PreLayerInferTpl): def __init__(self, network_config): super().__init__(network_config) - tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) - self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) self.normfactor = network_config["hidden_size"] ** 0.5 return @@ -22,20 +20,14 @@ def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight) return input * self.normfactor def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding( - input_ids=input_ids, - alloc_func=self.alloc_tensor, - ) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: 
all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding( - input_ids=input_ids, - alloc_func=self.alloc_tensor, - ) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index 6e052caa63..23ae50f096 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -5,14 +5,19 @@ class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) self.lm_head_weight_ = self.wte_weight_ self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, bias_name=None, diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 9102ce6775..19951b9900 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -11,23 +11,30 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): return def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + kv_out_dim = self.k_head_num_ * self.head_dim self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], weight_names=self._q_weight_name, data_type=self.data_type_, bias_names=self._q_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_proj", + quant_method=self.get_quant_method("q_proj"), ) self.kv_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim, kv_out_dim], weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_proj", + quant_method=self.get_quant_method("kv_proj"), ) def _init_norm(self): - self.att_norm_weight_ = NoTpGEMMANormWeight(self._att_norm_weight_name, self.data_type_) - self.ffn_norm_weight_ = NoTpGEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NoTpGEMMANormWeight( + dim=self.n_embed, weight_name=self._att_norm_weight_name, data_type=self.data_type_, bias_name=None + ) + self.ffn_norm_weight_ = NoTpGEMMANormWeight( + dim=self.n_embed, weight_name=self._ffn_norm_weight_name, data_type=self.data_type_, bias_name=None + ) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index c5c14b08e6..7c8c30940e 100644 --- 
a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -2,11 +2,14 @@ import torch import numpy as np -from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gpt_oss_fused_moe_weight_tp import ( + GPTOSSFusedMoeWeightTP, +) from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights import TpAttSinkWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -23,17 +26,18 @@ def __init__( return def _init_moe(self): - moe_mode = os.getenv("MOE_MODE", "TP") + enable_ep_moe = get_env_start_args().enable_ep_moe moe_intermediate_size = self.network_config_["intermediate_size"] n_routed_experts = self.network_config_["num_local_experts"] - assert moe_mode in ["TP"], "For now, GPT-OSS type model only support MOE TP mode." + assert not enable_ep_moe, "For now, GPT-OSS type model only support MOE TP mode." self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[n_routed_experts], weight_names=self._router_weight_name, data_type=self.data_type_, - layer_num=self.layer_num_, bias_names=self._router_bias_name, - name="moe_gate", + quant_method=self.get_quant_method("moe_gate"), tp_rank=0, tp_world_size=1, ) @@ -44,13 +48,13 @@ def _init_moe(self): e_score_correction_bias_name="", weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", n_routed_experts=n_routed_experts, - split_inter_size=moe_intermediate_size // self.tp_world_size_, + hidden_size=self.n_embed, + moe_intermediate_size=moe_intermediate_size, data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - world_size=self.tp_world_size_, # diff with FusedMoeWeightTP - quant_cfg=self.quant_cfg, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), num_fused_shared_experts=0, + layer_num=self.layer_num_, + network_config=self.network_config_, ) def _init_weight_names(self): @@ -68,6 +72,7 @@ def _init_weight(self): super()._init_weight() self.attn_sinks = TpAttSinkWeight( + all_q_head_num=self.q_head_num_, weight_name=f"model.layers.{self.layer_num_}.self_attn.sinks", data_type=torch.bfloat16, ) diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index dc5f2abdfe..9e9561eb24 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,9 +2,10 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry - from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class +from lightllm.common.basemodel.attention import BaseAttBackend logger = init_logger(__name__) @@ -19,9 +20,11 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) - assert ( - get_env_start_args().llm_prefill_att_backend[0] == "fa3" - ), "For now GPT-OSS type model only support flashattention-3" - assert ( - 
get_env_start_args().llm_decode_att_backend[0] == "fa3" - ), "For now GPT-OSS type model only support flashattention-3" + + def _init_att_backend(self): + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( + model=self + ) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])( + model=self + ) diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index 3ed7004c12..3bb526c79b 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -1,14 +1,26 @@ from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) - self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) - - self.final_norm_weight_ = NoTpNormWeight( + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.tok_embeddings.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="output.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py index a05e977f16..e528ee9b53 100755 --- a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py @@ -21,6 +21,10 @@ def load_hf_weights(self, weights): del weights[qkv_weight_name] super().load_hf_weights(weights) + def _parse_config(self): + super()._parse_config() + self.n_kv_head = self.network_config_["num_key_value_heads"] + def _init_weight_names(self): super()._init_weight_names() self._o_weight_name = f"model.layers.{self.layer_num_}.attention.wo.weight" diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index 59caf40d6b..b526192125 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,12 +1,15 @@ -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight, ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, RMSNormWeight, ROWMMWeight class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = 
network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.tok_embeddings.weight", data_type=self.data_type_, ) @@ -17,7 +20,8 @@ def __init__(self, data_type, network_config): tp_rank=0, tp_world_size=1, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 8bc10d623c..7714164151 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -19,7 +19,7 @@ def __init__(self, network_config): return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.final_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + return layer_weight.final_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo): embed_dim_ = input_embdings.shape[1] @@ -66,7 +66,7 @@ def token_forward( input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) + logic_batch = layer_weight.lm_head_weight_(input=last_input, alloc_func=self.alloc_tensor) last_input = None vocab_size = layer_weight.lm_head_weight_.vocab_size if self.tp_world_size_ == 1: diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index f4f150b173..63a2fe4d14 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -15,13 +15,13 @@ def __init__(self, network_config): return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 2a9a543196..dc6f10be59 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -4,7 +4,7 @@ from functools import partial from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from 
lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor @@ -69,16 +69,12 @@ def _token_attention_kernel( def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + return layer_weight.att_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.rmsnorm_forward( - input=input, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) + return layer_weight.ffn_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 7e9ff41673..8efa36cf80 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -1,27 +1,30 @@ from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_: LMHeadWeight = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="lm_head.weight", - data_type=self.data_type_, - ) - - self.final_norm_weight_ = NoTpNormWeight( + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, - bias_name=None, ) return diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 197116d99c..0566c9f1c6 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, 
RMSNormWeight, KVROWNMMWeight class LlamaTransformerLayerWeight(TransformerLayerWeight): @@ -23,11 +23,15 @@ def _init_weight(self): self._init_norm() def _parse_config(self): - self.n_embed = self.network_config_["hidden_size"] self.n_head = self.network_config_["num_attention_heads"] + self.q_head_num_ = self.network_config_["num_attention_heads"] + self.k_head_num_ = self.network_config_["num_key_value_heads"] + self.v_head_num_ = self.k_head_num_ + self.o_head_num_ = self.q_head_num_ + head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] + self.head_dim = self.network_config_.get("head_dim", head_dim) + self.n_embed = self.network_config_["hidden_size"] self.n_inter = self.network_config_["intermediate_size"] - self.n_kv_head = self.network_config_["num_key_value_heads"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" @@ -56,55 +60,65 @@ def _init_weight_names(self): self._ffn_norm_bias_name = None def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], weight_names=self._q_weight_name, data_type=self.data_type_, bias_names=self._q_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_proj", + quant_method=self.get_quant_method("q_proj"), ) - self.kv_proj = ROWMMWeight( + self.kv_proj = KVROWNMMWeight( + in_dim=in_dim, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_proj", + quant_method=self.get_quant_method("kv_proj"), ) def _init_o(self): + in_dim = self.o_head_num_ * self.head_dim + out_dim = self.n_embed self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], weight_names=self._o_weight_name, data_type=self.data_type_, bias_names=self._o_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_proj", + quant_method=self.get_quant_method("o_proj"), ) def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, ) - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, ) diff --git 
a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index dbe9b61c85..96a15d18a6 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -19,22 +19,10 @@ def _mtp_context_forward( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) - tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + tgt_embdings = layer_weight.final_norm_weight_(input=tgt_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) @@ -47,22 +35,10 @@ def _mtp_token_forward( tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) - tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + tgt_embdings = layer_weight.final_norm_weight_(input=tgt_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index c9032f6fee..5ec5bf7c18 100644 --- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -2,35 +2,38 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, LMHeadWeight, - NoTpNormWeight, + RMSNormWeight, ROWMMWeight, ) +from lightllm.common.quantization import Quantcfg class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): super().__init__(data_type, network_config) - + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], weight_names="mtp.eh_proj.weight", data_type=self.data_type_, - layer_num=0, - name="eh_proj", + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), tp_rank=0, tp_world_size=1, ) - self.enorm_weight_ = NoTpNormWeight( + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="mtp.enorm.weight", data_type=self.data_type_, - bias_name=None, ) - self.hnorm_weight_ = NoTpNormWeight( + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="mtp.hnorm.weight", data_type=self.data_type_, - bias_name=None, ) self.wte_weight_: EmbeddingWeight 
= None self.lm_head_weight_: LMHeadWeight = None - self.final_norm_weight_: NoTpNormWeight = None + self.final_norm_weight_: RMSNormWeight = None return diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 08f280b06c..8d3f94f8f3 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -1,5 +1,5 @@ from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, RMSNormWeight class MistralMTPTransformerLayerWeight(TransformerLayerWeight): @@ -7,6 +7,10 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__(layer_num, data_type, network_config, quant_cfg) return + def _parse_config(self): + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + def _init_weight_names(self): self._gate_weight_name = f"mtp.layers.{self.layer_num_}.mlp.gate_proj.weight" self._gate_bias_name = None @@ -24,23 +28,26 @@ def _init_weight(self): def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_norm(self): - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + hidden_size = self.network_config_["hidden_size"] + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, ) diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index 0132db80f5..7c64625ca8 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -48,9 +48,19 @@ def _init_mem_manager(self): def _init_weights(self, start_layer_index=None): assert start_layer_index is None - self.config["n_layer"] = 1 - super()._init_weights(start_layer_index=0) + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(0, self.config["n_layer"]) + ] self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 39e28d4655..51c62fd4cb 
100644
--- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
@@ -2,7 +2,8 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight
+from lightllm.utils.envs_utils import get_env_start_args
 
 logger = init_logger(__name__)
 
@@ -31,36 +32,29 @@ def _init_ffn(self):
 
     def _init_moe(self):
         inter_size = self.network_config_["intermediate_size"]
-        split_inter_size = inter_size // self.tp_world_size_
 
         self.moe_gate = ROWMMWeight(
+            in_dim=self.n_embed,
+            out_dims=[self.n_routed_experts],
             weight_names=self.moe_gate_weight_name,
             data_type=self.data_type_,
             bias_names=self.moe_gate_bias_name,
-            quant_cfg=self.quant_cfg,
-            layer_num=self.layer_num_,
-            name="moe_gate",
+            quant_method=self.get_quant_method("moe_gate"),
             tp_rank=0,
             tp_world_size=1,  # no tensor parallelism
         )
-
-        moe_mode = os.getenv("MOE_MODE", "TP")
-        assert moe_mode in ["TP"], f"Unsupported moe mode: {moe_mode}"
-
-        if moe_mode == "TP":
-            self.experts = create_tp_moe_wegiht_obj(
-                gate_proj_name="w1",
-                down_proj_name="w2",
-                up_proj_name="w3",
-                e_score_correction_bias_name="",
-                weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts",
-                n_routed_experts=self.n_routed_experts,
-                split_inter_size=split_inter_size,
-                data_type=self.data_type_,
-                network_config=self.network_config_,
-                layer_num=self.layer_num_,
-                quant_cfg=self.quant_cfg,
-                num_fused_shared_experts=0,
-            )
-        else:
-            raise ValueError(f"Unsupported moe mode: {moe_mode}")
+        assert not get_env_start_args().enable_ep_moe, "Mixtral only supports TP mode."
+ self.experts = FusedMoeWeight( + gate_proj_name="w1", + down_proj_name="w2", + up_proj_name="w3", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.n_embed, + moe_intermediate_size=inter_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index bf9282a979..52d1a54f59 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -1,21 +1,26 @@ -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="transformer.wte.weight", data_type=self.data_type_, ) self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="lm_head.weight", data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="transformer.ln_f.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 9c3e2cb3a8..fe6a5a2d49 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -10,42 +10,3 @@ def _init_weight_names(self): self._q_bias_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" - - def _parse_config(self): - self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) - self.tp_v_head_num_ = self.tp_k_head_num_ - self.tp_o_head_num_ = self.tp_q_head_num_ - head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] - self.head_dim = self.network_config_.get("head_dim", head_dim) - assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 - - def _repeat_weight(self, name, weights): - # for tp_world_size_ > num_key_value_heads - if name not in weights: - return - - tensor = weights[name] - num_kv_heads = self.network_config_["num_key_value_heads"] - repeat_size = (self.tp_k_head_num_ * self.tp_world_size_) // num_kv_heads - - if tensor.ndim == 1: - # Bias (1D tensor) - tensor = tensor.reshape(num_kv_heads, -1).unsqueeze(1).repeat(1, repeat_size, 1).reshape(-1) - else: - # Weight (2D tensor) - tensor = ( - tensor.reshape(num_kv_heads, -1, tensor.shape[-1]) - 
.unsqueeze(1) - .repeat(1, repeat_size, 1, 1) - .reshape(-1, tensor.shape[-1]) - ) - weights[name] = tensor - - def load_hf_weights(self, weights): - self._repeat_weight(self._k_weight_name, weights) - self._repeat_weight(self._v_weight_name, weights) - if self._k_bias_name is not None and self._v_bias_name is not None: - self._repeat_weight(self._k_bias_name, weights) - self._repeat_weight(self._v_bias_name, weights) - return super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 5f0c91287d..82331f8fb8 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -4,9 +4,6 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward -from functools import partial from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -27,14 +24,12 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( + layer_weight.q_norm_weight_( q, - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, ) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 90b7810adf..7d2163f283 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NoTpNormWeight, + QKRMSNORMWeight, ) @@ -19,6 +19,13 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - - self.q_norm_weight_ = NoTpNormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) - self.k_norm_weight_ = NoTpNormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self.q_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._q_norm_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._k_norm_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index c85c423c29..71b16cb34b 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -14,6 +14,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_global_world_size from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -41,8 +42,8 
@@ def _bind_func(self): def _bind_ffn(self): if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": + enable_ep_moe = get_env_start_args().enable_ep_moe + if enable_ep_moe: self._ffn = partial(Qwen3MOETransformerLayerInfer._moe_ffn_edp, self) self._tpsp_ffn = self._tpsp_ffn_ep else: @@ -60,20 +61,13 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - layer_weight.q_norm_weight_.rmsnorm_forward( - q.view(-1, self.head_dim_), + cache_kv = layer_weight.kv_proj.mm(input) + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, - out=q.view(-1, self.head_dim_), ) - - cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( - input=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), - eps=self.eps_, - alloc_func=self.alloc_tensor, - ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) - + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, : self.tp_k_head_num_, :], @@ -98,19 +92,13 @@ def _tpsp_get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - layer_weight.q_norm_weight_.rmsnorm_forward( - q.view(-1, self.head_dim_), + cache_kv = layer_weight.kv_proj.mm(input) + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, - out=q.view(-1, self.head_dim_), ) - - cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( - cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), - eps=self.eps_, - alloc_func=self.alloc_tensor, - ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 486f4d6966..a889609d7f 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import os from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): @@ -32,15 +32,6 @@ def _init_weight_names(self): self._ffn_norm_weight_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" self._ffn_norm_bias_name = None - def load_hf_weights(self, weights): - kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") - if self.quant_cfg.quantized_weight: - _k_scale_weight_name = self._k_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix) - self._repeat_weight(_k_scale_weight_name, weights) - 
_v_scale_weight_name = self._v_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix) - self._repeat_weight(_v_scale_weight_name, weights) - return super().load_hf_weights(weights) - def _init_weight(self): self._init_qkv() self._init_o() @@ -53,42 +44,25 @@ def _init_weight(self): def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=None, tp_rank=0, tp_world_size=1, ) - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode in ["EP", "TP"] - if moe_mode == "TP": - self.experts = create_tp_moe_wegiht_obj( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name="", - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - split_inter_size=moe_intermediate_size // self.tp_world_size_, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - num_fused_shared_experts=0, - ) - elif moe_mode == "EP": - self.experts = FusedMoeWeightEP( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name="", - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - ) - else: - raise ValueError(f"Unsupported moe mode: {moe_mode}") + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 8ba95c1386..3038a4d074 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -4,33 +4,37 @@ EmbeddingWeight, ROWMMWeight, LMHeadWeight, - NoTpNormWeight, + RMSNormWeight, ) +from lightllm.common.quantization import Quantcfg class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): super().__init__(data_type, network_config) - + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], weight_names="model.layers.0.proj.weight", + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), data_type=self.data_type_, - name="eh_proj", tp_rank=0, tp_world_size=1, ) - self.enorm_weight_ = NoTpNormWeight( + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.norm_after_embedding.weight", 
data_type=self.data_type_, - bias_name=None, ) - self.hnorm_weight_ = NoTpNormWeight( + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.norm_before_output.weight", data_type=self.data_type_, - bias_name=None, ) # 与Qwen3MOE模型共享 self.wte_weight_: EmbeddingWeight = None self.lm_head_weight_: LMHeadWeight = None - self.final_norm_weight_: NoTpNormWeight = None + self.final_norm_weight_: RMSNormWeight = None return diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index 095afecd91..12bb969801 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import os from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -16,6 +16,9 @@ def _init_weight(self): self._init_ffn() def _init_norm(self): - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + hidden_size = self.network_config_["hidden_size"] + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, ) diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 72aadbda80..9f83832a7e 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -41,7 +41,18 @@ def _init_mem_manager(self): def _init_weights(self, start_layer_index=None): assert start_layer_index is None mtp_index = len(self.mtp_previous_draft_models) - super()._init_weights(start_layer_index=mtp_index) + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(mtp_index, mtp_index + self.config["n_layer"]) + ] self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_ diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index d1c51365a1..d34babaabe 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -9,7 +9,6 @@ from lightllm.distributed import all_reduce from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -30,14 +29,12 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( + layer_weight.q_norm_weight_( q, - 
weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, ) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 328cc0a625..4ccc6da372 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -5,7 +5,6 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features @@ -26,14 +25,12 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( + layer_weight.q_norm_weight_( q, - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, ) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index 52a982f495..0ba06c6aed 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -1,24 +1,28 @@ -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.language_model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_: LMHeadWeight = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="lm_head.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.language_model.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py 
b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py index 48ddf52089..83c05ba264 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py @@ -1,6 +1,5 @@ import os from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index f908dbdd3b..55848ce66a 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -53,17 +53,13 @@ def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, t def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.att_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.ffn_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 3d044eeb56..885c7ead7d 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -1,10 +1,13 @@ -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight, NoTpNormWeight +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import LayerNormWeight class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - self.final_norm_weight_ = NoTpNormWeight( + hidden_size = network_config["hidden_size"] + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, bias_name="model.norm.bias", diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index 6b88c066ee..b3cd083c30 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -14,24 +14,18 @@ def __init__(self, network_config): self.layer_norm_eps_ = network_config["layer_norm_epsilon"] def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = 
layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = layer_weight.wpe_weight_.embedding( - input_ids=infer_state.position_ids, - alloc_func=self.alloc_tensor, - ) + position_embeds = layer_weight.wpe_weight_(input_ids=infer_state.position_ids, alloc_func=self.alloc_tensor) return input_embdings.add_(position_embeds) def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = layer_weight.wpe_weight_.embedding( - input_ids=infer_state.position_ids, - alloc_func=self.alloc_tensor, - ) + position_embeds = layer_weight.wpe_weight_(input_ids=infer_state.position_ids, alloc_func=self.alloc_tensor) return input_embdings.add_(position_embeds) diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 329a0245f0..a258480f65 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -1,7 +1,7 @@ from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, - NoTpNormWeight, + LayerNormWeight, NoTpPosEmbeddingWeight, LMHeadWeight, ) @@ -11,21 +11,32 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + def _create_weight(self): + hidden_size = self.network_config["hidden_size"] + vocab_size = self.network_config["vocab_size"] + max_position_embeddings = self.network_config["max_position_embeddings"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="transformer.wte.weight", data_type=self.data_type_, ) self.wpe_weight_ = NoTpPosEmbeddingWeight( + dim=hidden_size, + max_position_embeddings=max_position_embeddings, weight_name="transformer.wpe.weight", data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="transformer.ln_f.weight", bias_name="transformer.ln_f.bias", data_type=self.data_type_, ) self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="lm_head.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 41f24f79cb..4adbe4f5e6 100644 --- a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py @@ -51,18 +51,18 @@ def _init_weight_names(self): def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._gate_up_weight_name, data_type=self.data_type_, bias_names=self._gate_up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + 
quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 09e3299eb6..3e32682ecb 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -11,19 +11,15 @@ def __init__(self, layer_num, network_config): def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.att_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.ffn_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn( diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 6ee1885372..cc256c442d 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,26 +1,29 @@ -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, LayerNormWeight class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_: LMHeadWeight = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="lm_head.weight", - data_type=self.data_type_, - ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, bias_name="model.norm.bias", diff --git a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py index 53342e221f..7370c69530 100644 --- 
a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py @@ -28,18 +28,18 @@ def _init_weight_names(self): def _init_ffn(self): self.up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._up_weight_name, data_type=self.data_type_, bias_names=self._up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="up_proj", + quant_method=self.get_quant_method("up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) diff --git a/lightllm/models/vit/layer_infer/post_layer_infer.py b/lightllm/models/vit/layer_infer/post_layer_infer.py index fa4a87f158..0eb0c46040 100644 --- a/lightllm/models/vit/layer_infer/post_layer_infer.py +++ b/lightllm/models/vit/layer_infer/post_layer_infer.py @@ -15,6 +15,7 @@ def __init__(self, network_config): self.network_config_ = network_config self.llm_hidden_size = network_config["llm_hidden_size"] self.downsample_ratio = network_config["downsample_ratio"] + self.eps_ = network_config["layer_norm_eps"] return def pixel_shuffle(self, x, scale_factor=0.5): @@ -33,25 +34,12 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight): h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) - vit_embeds_norm = torch.nn.functional.layer_norm( - vit_embeds, - (vit_embeds.shape[-1],), - weight=layer_weight.layernorm_weight_, - bias=layer_weight.layernorm_bias_, - ) - - vit_embeds_1 = torch.addmm( - layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_ - ) + vit_embeds_norm = layer_weight.layernorm_weight_(input=vit_embeds, eps=self.eps_) + vit_embeds_1 = layer_weight.mlp1_1_.mm(vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1])) vit_embeds_gelu = gelu_fwd(vit_embeds_1, use_custom_tensor_mananger=True) - vit_embeds_out = torch.addmm( - layer_weight.mlp1_3_bias_, - vit_embeds_gelu.view(-1, self.llm_hidden_size // self.tp_world_size_), - layer_weight.mlp1_3_weight_, - beta=1.0 / self.tp_world_size_, - ) + vit_embeds_out = layer_weight.mlp1_3_.mm(vit_embeds_gelu.view(-1, self.llm_hidden_size // self.tp_world_size_)) if self.tp_world_size_ == 1: return vit_embeds_out.view(batch_size, -1, self.llm_hidden_size) diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index e2bed10361..73eb0b46ac 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -4,15 +4,60 @@ import torch.nn.functional as F from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.basemodel.layer_weights.meta_weights import LayerNormWeight, COLMMWeight, ROWMMWeight +from lightllm.common.quantization import Quantcfg class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config): + def __init__(self, data_type, network_config, quant_cfg): super().__init__(data_type, network_config) self.embed_dim = 
self.network_config_["hidden_size"] self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] + self.downsample_ratio = self.network_config_["downsample_ratio"] + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.quant_cfg: Quantcfg = quant_cfg + self._create_weight() + return + + def _create_weight(self): + split_indexes = np.linspace(0, self.embed_dim, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_embed_dim = split_end - split_start + + # Pre-allocate memory for vision model weights + self.class_embedding = torch.empty((1, 1, split_embed_dim), dtype=self.data_type_).cuda() + self.position_embedding = torch.empty((1, self.num_positions, split_embed_dim), dtype=self.data_type_).cuda() + self.patch_embedding_weight_ = torch.empty( + (split_embed_dim, 3, self.patch_size, self.patch_size), dtype=self.data_type_ + ).cuda() + self.patch_embedding_bias_ = torch.empty(split_embed_dim, dtype=self.data_type_).cuda() + + self.layernorm_weight_ = LayerNormWeight( + dim=self.embed_dim * int(1 / self.downsample_ratio) ** 2, + weight_name="mlp1.0.weight", + data_type=self.data_type_, + bias_name="mlp1.0.bias", + ) + self.mlp1_1_ = ROWMMWeight( + in_dim=self.embed_dim * int(1 / self.downsample_ratio) ** 2, + out_dims=[self.llm_hidden_size], + weight_names=["mlp1.1.weight"], + data_type=self.data_type_, + bias_names=["mlp1.1.bias"], + quant_method=self.quant_cfg.get_quant_method(-1, "mlp1_1"), + ) + self.mlp1_3_ = COLMMWeight( + in_dim=self.llm_hidden_size, + out_dims=[self.llm_hidden_size], + weight_names=["mlp1.3.weight"], + data_type=self.data_type_, + bias_names=["mlp1.3.bias"], + quant_method=self.quant_cfg.get_quant_method(-1, "mlp1_3"), + ) return def _cuda(self, cpu_tensor): @@ -36,45 +81,24 @@ def _get_pos_embed(self, H, W): return pos_embed def load_hf_weights(self, weights): + super().load_hf_weights(weights) split_indexes = np.linspace(0, self.embed_dim, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "vision_model.embeddings.class_embedding" in weights: - self.class_embedding = self._cuda( - weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end] - ) + self.class_embedding.copy_(weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end]) if "vision_model.embeddings.position_embedding" in weights: - self.position_embedding = self._cuda( + self.position_embedding.copy_( weights["vision_model.embeddings.position_embedding"][:, :, split_start:split_end] ) if "vision_model.embeddings.patch_embedding.weight" in weights: - self.patch_embedding_weight_ = self._cuda( + self.patch_embedding_weight_.copy_( weights["vision_model.embeddings.patch_embedding.weight"][split_start:split_end, :, :, :] ) if "vision_model.embeddings.patch_embedding.bias" in weights: - self.patch_embedding_bias_ = self._cuda( + self.patch_embedding_bias_.copy_( weights["vision_model.embeddings.patch_embedding.bias"][split_start:split_end] ) - - if "mlp1.0.weight" in weights: - self.layernorm_weight_ = self._cuda(weights["mlp1.0.weight"]) - if "mlp1.0.bias" in weights: - self.layernorm_bias_ = self._cuda(weights["mlp1.0.bias"]) - - split_indexes = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) - split_start = 
split_indexes[self.tp_rank_] - split_end = split_indexes[self.tp_rank_ + 1] - - if "mlp1.1.weight" in weights: - self.mlp1_1_weight_ = self._cuda(weights["mlp1.1.weight"][split_start:split_end, :]).t() - if "mlp1.1.bias" in weights: - self.mlp1_1_bias_ = self._cuda(weights["mlp1.1.bias"][split_start:split_end]) - - if "mlp1.3.weight" in weights: - self.mlp1_3_weight_ = self._cuda(weights["mlp1.3.weight"][:, split_start:split_end]).t() - if "mlp1.3.bias" in weights: - self.mlp1_3_bias_ = self._cuda(weights["mlp1.3.bias"]) - return def verify_load(self): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index dffcc16fe8..198b3022be 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -7,8 +7,9 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NoTpNormWeight, - TpVitPadNormWeight, + RMSNormWeight, + LayerNormWeight, + TpRMSNormWeight, ) from lightllm.utils.dist_utils import get_current_device_id @@ -29,6 +30,9 @@ def _parse_config(self): self.qkv_bias = self.network_config_.get("qkv_bias", True) self.layer_norm_eps = self.network_config_.get("layer_norm_eps", 1e-6) self.norm_type = self.network_config_.get("norm_type", "layer_norm") + self.n_embed = self.network_config_["hidden_size"] + self.padding_hidden_size + mlp_ratio = self.network_config_.get("mlp_ratio", 4) + self.n_inter = self.network_config_.get("intermediate_size", int(self.n_embed * mlp_ratio)) def _init_weight_names(self): self._att_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" @@ -81,54 +85,85 @@ def _init_weight(self): def _init_qkv(self): self.qkv_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_embed, self.n_embed, self.n_embed], weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="qkv_proj", + quant_method=self.get_quant_method("qkv_proj"), ) def _init_o(self): self.o_proj = COLMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_embed], weight_names=self._o_weight_name, data_type=self.data_type_, bias_names=self._o_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_proj", + quant_method=self.get_quant_method("o_proj"), ) def _init_ffn(self): self.ffn_1_proj_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self.fc1_weight_name_, data_type=self.data_type_, bias_names=self.fc1_bias_name_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="ffn_1_proj", + quant_method=self.get_quant_method("ffn_1_proj"), ) self.ffn_2_proj_ = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self.fc2_weight_name_, data_type=self.data_type_, bias_names=self.fc2_bias_name_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="ffn_2_proj", + quant_method=self.get_quant_method("ffn_2_proj"), ) def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name - ) - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name - ) + hidden_size = self.network_config_["hidden_size"] + if self.norm_type == "rms_norm": + self.att_norm_weight_ = 
RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + else: + self.att_norm_weight_ = LayerNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + bias_name=self._att_norm_bias_name, + ) + self.ffn_norm_weight_ = LayerNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + bias_name=self._ffn_norm_bias_name, + ) if self.qk_norm: head_num = self.network_config_["num_attention_heads"] - self.q_norm_weight_ = TpVitPadNormWeight(self._q_norm_weight_name, self.data_type_, head_num=head_num) - self.k_norm_weight_ = TpVitPadNormWeight(self._k_norm_weight_name, self.data_type_, head_num=head_num) + head_dim = self.network_config_["hidden_size"] // head_num + head_dim = self.network_config_.get("head_dim", head_dim) + self.q_norm_weight_ = TpRMSNormWeight( + head_num=head_num, + head_dim=head_dim, + weight_name=self._q_norm_weight_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = TpRMSNormWeight( + head_num=head_num, + head_dim=head_dim, + weight_name=self._k_norm_weight_name, + data_type=self.data_type_, + ) def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 9c2bc42426..13f8e2827f 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -111,7 +111,9 @@ def _padding_hidden_size(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) self.trans_layers_weight = [ self.transformer_weight_class( i, diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc388223..e49b0cc67c 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -465,10 +465,8 @@ def make_argument_parser() -> argparse.ArgumentParser: "--quant_type", type=str, default="none", - help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16 - | ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16 - | vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 - | triton-fp8w8a8-block128""", + help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 + | deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin""", ) parser.add_argument( "--quant_cfg", @@ -481,9 +479,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--vit_quant_type", type=str, default="none", - help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16 - | ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16 - | vllm-w8a8 | vllm-fp8w8a8""", + help="""Quantization method for ViT: vllm-w8a8 | vllm-fp8w8a8""", ) parser.add_argument( "--vit_quant_cfg", @@ -520,6 +516,11 @@ def make_argument_parser() -> argparse.ArgumentParser: " Therefore, it is recommended to set this parameter according to actual needs." 
        ),
    )
+    parser.add_argument(
+        "--enable_ep_moe",
+        action="store_true",
+        help="""Whether to enable EP MoE (expert parallelism) for the deepseekv3 model.""",
+    )
     parser.add_argument(
         "--ep_redundancy_expert_config_path",
         type=str,
@@ -534,7 +535,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_fused_shared_experts",
         action="store_true",
-        help="""Whether to enable fused shared experts for deepseekv3 model. only work when MOE_MODE=TP """,
+        help="""Whether to enable fused shared experts for deepseekv3 model. Only works in MoE TP mode.""",
     )
     parser.add_argument(
         "--mtp_mode",
@@ -608,4 +609,24 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
+    parser.add_argument(
+        "--hardware_platform",
+        type=str,
+        default="cuda",
+        choices=["cuda", "musa"],
+        help="""Hardware platform: cuda | musa""",
+    )
+    parser.add_argument(
+        "--enable_torch_fallback",
+        action="store_true",
+        help="""Whether to fall back to the naive torch implementation for ops.
+        If an op is not implemented for the platform, the naive torch implementation will be used.""",
+    )
+    parser.add_argument(
+        "--enable_triton_fallback",
+        action="store_true",
+        help="""Whether to fall back to the triton implementation for ops.
+        If an op is not implemented for the platform and the hardware supports triton,
+        the triton implementation will be used.""",
+    )
     return parser
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 239cebfdd1..059cd739f1 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -127,6 +127,7 @@ class StartArgs:
     penalty_counter_mode: str = field(
         default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]}
     )
+    enable_ep_moe: bool = field(default=False)
     ep_redundancy_expert_config_path: Optional[str] = field(default=None)
     auto_update_redundancy_expert: bool = field(default=False)
     mtp_mode: Optional[str] = field(
diff --git a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py
index 811d39a729..596eca4f24 100644
--- a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py
+++ b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py
@@ -8,10 +8,10 @@
 import json
 from typing import List
 from lightllm.common.basemodel.basemodel import TpPartBaseModel
-from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep_redundancy import (
+from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.ep_redundancy import (
     FusedMoeWeightEPAutoRedundancy,
 )
-from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep import FusedMoeWeightEP
+from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight
 from lightllm.utils.envs_utils import get_env_start_args, get_redundancy_expert_update_interval
 from lightllm.utils.envs_utils import get_redundancy_expert_update_max_load_count
 from lightllm.utils.envs_utils import get_redundancy_expert_num
@@ -28,7 +28,7 @@ def __init__(self, model: TpPartBaseModel):
         self.model = model
         self.ep_fused_moeweights: List[FusedMoeWeightEPAutoRedundancy] = []
         for layer in self.model.trans_layers_weight:
-            ep_weights = self._find_members_of_class(layer, FusedMoeWeightEP)
+            ep_weights = self._find_members_of_class(layer, FusedMoeWeight)
             assert len(ep_weights) <= 1
             self.ep_fused_moeweights.extend([FusedMoeWeightEPAutoRedundancy(e) for e in ep_weights])
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index 09d7a680fe..a1ed6ed950 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -3,6 +3,8 @@
 import torch
 import shutil
 import subprocess
+from enum import Enum
+from typing import Optional
 from functools import lru_cache
 from lightllm.utils.log_utils import init_logger
@@ -284,3 +286,42 @@ def is_5090_gpu() -> bool:
         return False
     except:
         return False
+
+
+class Platform(Enum):
+    """hardware platform enum"""
+
+    CUDA = "cuda"
+    ASCEND = "ascend"  # ascend
+    CAMBRICON = "cambricon"  # cambricon
+    MUSA = "musa"  # musa
+    ROCM = "rocm"  # rocm
+    CPU = "cpu"  # cpu
+
+
+# Currently only cuda and musa are supported
+def get_platform(platform_name: Optional[str] = None) -> Platform:
+    """
+    Get the hardware platform.
+
+    Args:
+        platform_name: platform name (cuda, ascend, cambricon, musa, rocm, cpu)
+
+    Returns:
+        Platform: platform enum value
+    """
+    assert platform_name in ["cuda", "musa"], f"Only support cuda and musa now, but got {platform_name}"
+    platform_name = platform_name.lower()
+    platform_map = {
+        "cuda": Platform.CUDA,
+        "ascend": Platform.ASCEND,
+        "cambricon": Platform.CAMBRICON,
+        "musa": Platform.MUSA,
+        "rocm": Platform.ROCM,
+        "cpu": Platform.CPU,
+    }
+
+    platform = platform_map.get(platform_name)
+    if platform is None:
+        raise ValueError(f"Unknown platform name: {platform_name}")
+    return platform
diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md
index e00af27139..8ed44a2753 100644
--- a/test/start_scripts/README.md
+++ b/test/start_scripts/README.md
@@ -99,7 +99,6 @@ sh multi_pd_master/pd_decode.sh
 ### Environment Variables
 - `LOADWORKER`: Model loading thread count, recommended 8-18
-- `MOE_MODE`: Expert parallelism mode, set to EP to enable expert parallelism
 - `DISABLE_KV_TRANS_USE_P2P`: Disable P2P communication optimization to transfer kv data
 - `CUDA_VISIBLE_DEVICES`: Specify GPU devices to use
@@ -108,6 +107,7 @@ sh multi_pd_master/pd_decode.sh
 - `--model_dir`: Model file path
 - `--tp`: Tensor parallelism degree
 - `--dp`: Data parallelism degree
+- `--enable_ep_moe`: Enable expert parallelism
 - `--nnodes`: Total number of nodes
 - `--node_rank`: Current node rank
 - `--nccl_host`: NCCL communication host address
diff --git a/test/start_scripts/multi_node_ep_node0.sh b/test/start_scripts/multi_node_ep_node0.sh
index cd72e6cfc6..2cc6b03c90 100644
--- a/test/start_scripts/multi_node_ep_node0.sh
+++ b/test/start_scripts/multi_node_ep_node0.sh
@@ -2,14 +2,14 @@
 # nccl_host: the ip of the nccl host
 # sh multi_node_ep_node0.sh
 export nccl_host=$1
-MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
+LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
 --model_dir /path/DeepSeek-R1 \
 --tp 16 \
 --dp 16 \
 --nnodes 2 \
 --node_rank 0 \
 --nccl_host $nccl_host \
---nccl_port 2732
+--nccl_port 2732 --enable_ep_moe
 # if you want to enable microbatch overlap, you can uncomment the following lines
 #--enable_prefill_microbatch_overlap
 #--enable_decode_microbatch_overlap
\ No newline at end of file
diff --git a/test/start_scripts/multi_node_ep_node1.sh b/test/start_scripts/multi_node_ep_node1.sh
index 17b878a1b6..cc920b0b05 100644
--- a/test/start_scripts/multi_node_ep_node1.sh
+++ b/test/start_scripts/multi_node_ep_node1.sh
@@ -2,14 +2,14 @@
 # nccl_host: the ip of the nccl host
 # sh
multi_node_ep_node1.sh export nccl_host=$1 -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ +LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ ---nccl_port 2732 +--nccl_port 2732 --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap #--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh index 41ad525514..45f6c0c011 100644 --- a/test/start_scripts/multi_pd_master/pd_prefill.sh +++ b/test/start_scripts/multi_pd_master/pd_prefill.sh @@ -5,7 +5,7 @@ export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --host $host \ @@ -15,6 +15,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --nccl_port 2732 \ --disable_cudagraph \ --config_server_host $config_server_host \ ---config_server_port 60088 +--config_server_port 60088 \ +--enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/single_node_ep.sh b/test/start_scripts/single_node_ep.sh index e143c34ece..21d2ebaa3a 100644 --- a/test/start_scripts/single_node_ep.sh +++ b/test/start_scripts/single_node_ep.sh @@ -1,8 +1,9 @@ # H200 single node deepseek R1 dpep mode -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ +LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ ---dp 8 +--dp 8 \ +--enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index 9601d51174..dac7a6dac6 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -5,7 +5,7 @@ export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ @@ -13,6 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ +--enable_ep_moe \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 931fee8626..4b3fd0bc4e 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -10,7 +10,7 @@ export UCX_LOG_LEVEL=info export UCX_TLS=rc,cuda,gdr_copy nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "nixl_decode" \ --tp 8 \ @@ -18,6 +18,7 @@ MOE_MODE=EP LOADWORKER=18 python -m 
lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ +--enable_ep_moe \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index 6363207cb7..f415919f90 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -11,7 +11,7 @@ export UCX_TLS=rc,cuda,gdr_copy export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "nixl_prefill" \ --tp 8 \ @@ -19,6 +19,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ +--enable_ep_moe \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index 0c1bd26590..6bde9ef32c 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -5,7 +5,7 @@ export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -15,6 +15,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --nccl_port 2732 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ ---pd_master_port 60011 +--pd_master_port 60011 \ +--enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ No newline at end of file diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 9a613f6f7d..9c08cfc1a4 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -2,7 +2,12 @@ import time import pytest import triton -from lightllm.common.fused_moe.grouped_fused_moe import moe_align, moe_align1, moe_align2, grouped_matmul +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import ( + moe_align, + moe_align1, + moe_align2, + grouped_matmul, +) from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe_speed.py b/unit_tests/common/fused_moe/test_grouped_fused_moe_speed.py index 03beccdf98..769002517d 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe_speed.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe_speed.py @@ -1,7 +1,7 @@ import torch import time import pytest -from lightllm.common.fused_moe.grouped_fused_moe import moe_align, moe_align1, grouped_matmul +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import moe_align, moe_align1, grouped_matmul from lightllm.utils.log_utils import init_logger seed = 42 diff --git a/unit_tests/common/fused_moe/test_grouped_topk.py b/unit_tests/common/fused_moe/test_grouped_topk.py index 37c3fabc75..432e133163 100755 --- a/unit_tests/common/fused_moe/test_grouped_topk.py +++ b/unit_tests/common/fused_moe/test_grouped_topk.py @@ -2,8 +2,8 @@ import time import pytest import 
numpy as np -from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk -from lightllm.common.fused_moe.topk_select import biased_grouped_topk as grouped_topk +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_topk import triton_grouped_topk +from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import biased_grouped_topk as grouped_topk from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py index 671805a3d2..29aed2a70e 100644 --- a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py +++ b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py @@ -14,9 +14,11 @@ def is_fp8_native_supported(): pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True) import random -from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd -from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( + silu_and_mul_masked_post_quant_fwd, +) +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -37,7 +39,7 @@ def is_fp8_native_supported(): ) def test_silu_and_mul_masked(expert_num, token_num, hidden_dim): quant_group_size = 128 - in_tensor = torch.randn((expert_num, token_num, hidden_dim), dtype=torch.float16, device="cuda") + in_tensor = torch.randn((expert_num, token_num, hidden_dim), dtype=torch.bfloat16, device="cuda") out_tensor = torch.empty((expert_num, token_num, hidden_dim // 2), dtype=torch.float8_e4m3fn, device="cuda") out_scale_tensor = torch.randn( (expert_num, token_num, hidden_dim // 2 // quant_group_size), dtype=torch.float32, device="cuda" diff --git a/unit_tests/common/fused_moe/test_softmax_topk.py b/unit_tests/common/fused_moe/test_softmax_topk.py index 6252dfa8c3..7c3e483df8 100755 --- a/unit_tests/common/fused_moe/test_softmax_topk.py +++ b/unit_tests/common/fused_moe/test_softmax_topk.py @@ -2,7 +2,7 @@ import time import pytest import numpy as np -from lightllm.common.fused_moe.softmax_topk import softmax_topk +from lightllm.common.basemodel.triton_kernel.fused_moe.softmax_topk import softmax_topk from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py index 2c0b7bf76e..e6a0d52c72 100644 --- a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py +++ b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py @@ -1,7 +1,7 @@ import torch import pytest import torch.nn.functional as F -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import fp8_scaled_mm_per_token def is_fp8_native_supported():