From 36c4ea0d3851ed8c3a72284d6eb6683082efe65c Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Sat, 13 Jun 2026 20:16:13 +0800 Subject: [PATCH] feat: qwen3.5 perf opt --- .../basemodel/attention/create_utils.py | 4 +- lightllm/common/basemodel/attention/fa3/fp.py | 35 ++- .../basemodel/attention/flashinfer/fp.py | 106 ++++++- .../basemodel/attention/flashinfer/mla.py | 8 +- lightllm/common/basemodel/basemodel.py | 17 +- lightllm/common/basemodel/batch_objs.py | 3 + lightllm/common/basemodel/infer_struct.py | 2 + .../transformer_layer_infer_template.py | 15 +- .../fused_moe/fused_moe_weight.py | 4 + .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/deepgemm_impl.py | 2 + .../fused_moe/impl/marlin_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 8 + .../layer_weights/meta_weights/norm_weight.py | 17 +- .../fused_moe/grouped_fused_moe.py | 48 ++- .../triton_kernel/fused_moe/moe_sum_reduce.py | 56 +++- .../triton_kernel/norm/gated_rmsnorm.py | 7 - .../basemodel/triton_kernel/norm/rmsnorm.py | 119 +++++++- .../triton_kernel/repack_kv_index.py | 7 +- lightllm/distributed/communication_op.py | 39 +++ lightllm/distributed/flashinfer_all_reduce.py | 35 +++ .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_infer/transformer_layer_infer.py | 77 +++-- .../layer_weights/transformer_layer_weight.py | 127 ++++++-- lightllm/models/qwen3next/model.py | 102 +++++++ .../triton_kernel/gdn_decode_pack.py | 284 ++++++++++++++++++ .../triton_kernel/shared_expert_gate.py | 108 +++++++ lightllm/server/api_cli.py | 2 +- lightllm/server/api_openai.py | 2 + lightllm/server/core/objs/sampling_params.py | 3 + .../model_infer/mode_backend/base_backend.py | 63 +++- .../mode_backend/chunked_prefill/impl.py | 17 +- .../mode_backend/generic_post_process.py | 278 +++++++++++++++-- lightllm/utils/sgl_utils.py | 4 +- 34 files changed, 1510 insertions(+), 111 deletions(-) create mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py create mode 100644 lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 594e81a9b4..08d15294bd 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -100,7 +100,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) -def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: +def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] @@ -120,7 +120,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) -def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: +def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..48deabaf4b 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -3,12 +3,16 @@ from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, get_scheduler_metadata from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +_DECODE_MAX_NUM_SPLITS = 32 +_DECODE_PACK_GQA = True + + class Fa3AttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) @@ -119,6 +123,7 @@ class Fa3DecodeAttState(BaseDecodeAttState): cu_seqlens_k: torch.Tensor = None page_table: torch.Tensor = None b_att_seq_len: torch.Tensor = None + scheduler_metadata: torch.Tensor = None # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 decode_max_q_seq_len: int = None @@ -179,8 +184,33 @@ def init_state(self): ) self.b_att_seq_len = self.infer_state.b_seq_len self.decode_max_q_seq_len = 1 + self._init_scheduler_metadata() return + def _init_scheduler_metadata(self): + if get_scheduler_metadata is None: + self.scheduler_metadata = None + return + + model = self.backend.model + self.scheduler_metadata = get_scheduler_metadata( + batch_size=self.b_att_seq_len.shape[0], + max_seqlen_q=self.decode_max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + num_heads=model.config["num_attention_heads"] // model.tp_world_size_, + num_heads_k=model.tp_k_head_num_, + headdim=model.head_dim_, + cache_seqlens=self.b_att_seq_len, + qkv_dtype=model.data_type, + headdim_v=model.head_dim_, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + page_size=1, + causal=True, + num_splits=_DECODE_MAX_NUM_SPLITS, + pack_gqa=_DECODE_PACK_GQA, + ) + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): super().copy_for_decode_cuda_graph(new_state) @@ -235,6 +265,9 @@ def _normal_decode_att( causal=True, window_size=window_size, softcap=0.0, + scheduler_metadata=self.scheduler_metadata, + num_splits=_DECODE_MAX_NUM_SPLITS, + pack_gqa=_DECODE_PACK_GQA, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..b5145ac932 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -6,6 +6,82 @@ from .env_utils import set_flashinfer_envs +def _fast_plan_tensor_core_decode( + decode_wrapper, + indptr, + indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + indptr_host, + kv_lens_arr_host, + max_kv_len, +): + batch_size = len(last_page_len) + if batch_size != decode_wrapper._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} " + "mismatches the batch size set during initialization {}".format( + batch_size, decode_wrapper._fixed_batch_size + ) + ) + if len(indices) > len(decode_wrapper._paged_kv_indices_buf): + raise ValueError("The size of indices should be less than or equal to the allocated buffer") + + qo_indptr_host = getattr(decode_wrapper, "_lightllm_qo_indptr_host", None) + if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1: + from flashinfer.decode import _get_range_buf + + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + decode_wrapper._lightllm_qo_indptr_host = qo_indptr_host + + if indptr_host is None: + indptr_host = indptr.to("cpu") + if kv_lens_arr_host is None: + from flashinfer.decode import get_seq_lens + + last_page_len_host = last_page_len.to("cpu") + kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + if max_kv_len is None: + max_kv_len = max(kv_lens_arr_host).item() + + decode_wrapper._batch_size = batch_size + decode_wrapper._num_qo_heads = num_qo_heads + decode_wrapper._num_kv_heads = num_kv_heads + decode_wrapper._block_tables = None + decode_wrapper._max_kv_len = max_kv_len + + args = [ + decode_wrapper._float_workspace_buffer, + decode_wrapper._int_workspace_buffer, + decode_wrapper._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_host, + kv_lens_arr_host, + batch_size, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + decode_wrapper.is_cuda_graph_enabled, + head_dim, + head_dim, + False, + -1, + ] + if decode_wrapper._backend == "fa2": + args.extend([-1, False, 0]) + decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args) + decode_wrapper._pos_encoding_mode = "NONE" + decode_wrapper._window_left = -1 + decode_wrapper._logits_soft_cap = 0.0 + decode_wrapper._sm_scale = None + decode_wrapper._rope_scale = None + decode_wrapper._rope_theta = None + + class FlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -25,6 +101,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] self.q_data_type = model.data_type self.kv_data_type = model.data_type @@ -124,11 +204,11 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_last_page_len_buffer: torch.Tensor = None kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + kv_starts_host: torch.Tensor = None + kv_seq_lens_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: FlashInferAttBackend = self.backend device = self.infer_state.input_ids.device model = self.backend.model @@ -154,8 +234,21 @@ def init_state(self): self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, + zero_output=False, ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.infer_state.b_seq_len_cpu is not None: + self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.backend.workspace_buffer, @@ -182,7 +275,8 @@ def init_state(self): def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( + _fast_plan_tensor_core_decode( + self.decode_wrapper, new_state.kv_starts, new_state.kv_indices, new_state.kv_last_page_len_buffer, @@ -190,9 +284,9 @@ def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): new_state.backend.tp_kv_head_num, new_state.backend.head_dim, 1, - q_data_type=new_state.backend.q_data_type, - kv_data_type=new_state.backend.kv_data_type, - non_blocking=True, + new_state.kv_starts_host, + new_state.kv_seq_lens_host, + new_state.infer_state.max_kv_seq_len, ) def decode_att( diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45a..8689839db4 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -116,8 +116,6 @@ class MlaFlashInferDecodeAttState(BaseDecodeAttState): decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: MlaFlashInferAttBackend = self.backend model = self.backend.model device = self.infer_state.input_ids.device @@ -144,7 +142,13 @@ def init_state(self): self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, + zero_output=False, ) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a2..e6d5735385 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -314,6 +314,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len + infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu infer_state.b_mtp_index = model_input.b_mtp_index if model_input.is_prefill: if model_input.b_ready_cache_len is not None: @@ -371,6 +372,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 ) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + if new_model_input.b_seq_len_cpu is not None: + new_model_input.b_seq_len_cpu = F.pad( + new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2 + ) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -562,6 +567,8 @@ def _decode( model_input=model_input, new_batch_size=infer_batch_size ) infer_state = self._create_inferstate(model_input) + need_capture = self.graph.need_capture(infer_batch_size) + infer_state.skip_decode_att_wrapper_init = not need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -571,7 +578,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if need_capture: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: @@ -1037,6 +1044,9 @@ def autotune_layers(self): # 控制autotune的层数,用于适配不同模型 return self.config.get("first_k_dense_replace", 0) + 1 + def _autotune_extra_warmup(self): + return + @final @torch.no_grad() @post_empty_cache @@ -1106,6 +1116,11 @@ def _autotune_warmup(self): self.mem_manager.free_all() gc.collect() torch.cuda.empty_cache() + try: + self._autotune_extra_warmup() + except Exception as e: + logger.warning(f"extra autotune warmup failed: {str(e)}") + logger.exception(str(e)) self.layers_num = layer_num_bak torch.distributed.barrier() Autotuner.end_autotune_warmup() diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a82..81bf3cfd5c 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -42,6 +42,7 @@ class ModelInput: multimodal_params: list = None # cpu 变量 mem_indexes_cpu: torch.Tensor = None + b_seq_len_cpu: torch.Tensor = None # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 @@ -64,6 +65,8 @@ def to_cuda(self): assert self.is_prefill self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) + if not self.b_seq_len.is_cuda: + self.b_seq_len_cpu = self.b_seq_len self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) if self.b_ready_cache_len is not None: diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..575f1ee25f 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -40,6 +40,7 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None self.b_seq_len: torch.Tensor = None + self.b_seq_len_cpu: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -56,6 +57,7 @@ def __init__(self): self.return_all_prompt_logics: bool = False self.multimodal_params: dict = None self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理 + self.skip_decode_att_wrapper_init: bool = False self.dist_group: CustomProcessGroup = None # 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号 diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index f0cc129c09..061658bc07 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -53,6 +53,18 @@ def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tens def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") + def _add_residual_ffn_norm(self, input_embdings, residual, infer_state: InferStateInfo, layer_weight): + add_rmsnorm = getattr(layer_weight.ffn_norm_weight_, "add_rmsnorm", None) + if add_rmsnorm is None: + input_embdings.add_(residual.view(-1, self.embed_dim_)) + return self._ffn_norm(input_embdings, infer_state, layer_weight) + return add_rmsnorm( + input=input_embdings, + residual=residual.view(-1, self.embed_dim_), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) @@ -89,10 +101,9 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): input1 = self._att_norm(input_embdings, infer_state, layer_weight) o = self.token_attention_forward(input1, infer_state, layer_weight) - input_embdings.add_(o.view(-1, self.embed_dim_)) + input1 = self._add_residual_ffn_norm(input_embdings, o, infer_state, layer_weight) o = None - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..d9a77b39a5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -134,6 +134,8 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -150,6 +152,8 @@ def experts( num_expert_group=num_expert_group, is_prefill=is_prefill, per_expert_scale=self.per_expert_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index dd6f9a6880..b54b03ee05 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -63,5 +63,7 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..bc0e86d7eb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,6 +76,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): output = fused_experts( hidden_states=input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..417d001c72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,6 +30,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..fdda2b2139 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -94,6 +94,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -111,6 +113,8 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return input_tensor @@ -129,6 +133,8 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -150,5 +156,7 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index ee9d1923c3..33b59ac5a9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -2,7 +2,7 @@ from typing import Optional, Dict from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size -from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import add_rmsnorm_forward, rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward @@ -71,6 +71,21 @@ def __call__( ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def add_rmsnorm( + self, + input: torch.Tensor, + residual: torch.Tensor, + eps: float, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + input.ndim in [2, 3] and residual.ndim in [2, 3] and self.weight.ndim == 1 + ), f"input.ndim: {input.ndim}, residual.ndim: {residual.ndim}, weight.ndim: {self.weight.ndim}" + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return add_rmsnorm_forward(x=input, residual=residual, weight=self.weight, eps=eps, out=out) + class GatedRMSNormWeight(RMSNormWeight): def _triton_forward( diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..95dcba9836 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -221,10 +221,17 @@ def moe_align_fused_kernel( expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + expert_num: tl.constexpr, topk_num: tl.constexpr, BLOCK_SIZE: tl.constexpr, + ZERO_EXPERT_TOKEN_NUM: tl.constexpr, + BLOCK_EXPERT: tl.constexpr, ): token_block = tl.program_id(0) + if ZERO_EXPERT_TOKEN_NUM: + expert_offs = tl.arange(0, BLOCK_EXPERT) + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num @@ -282,6 +289,8 @@ def moe_align_fused( run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) + expert_num = expert_token_num.shape[0] + zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( @@ -291,8 +300,11 @@ def moe_align_fused( expert_to_weight, expert_token_num, token_num, + expert_num, topk_num, BLOCK_SIZE=BLOCK_SIZE, + ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, + BLOCK_EXPERT=triton.next_power_of_2(expert_num), num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -911,6 +923,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -957,7 +971,12 @@ def fused_experts_impl( expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_token_count_in_align_kernel = topk_num * tokens_in_chunk <= 128 + expert_to_token_num = ( + torch.empty((E,), dtype=torch.int32, device="cuda") + if expert_token_count_in_align_kernel + else torch.zeros((E,), dtype=torch.int32, device="cuda") + ) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, @@ -1011,8 +1030,15 @@ def fused_experts_impl( bias=w2_bias, ) + has_shared_gate = shared_expert_out is not None moe_sum_reduce( - intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx] + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + shared=None if not has_shared_gate else shared_expert_out[begin_chunk_idx:end_chunk_idx], + gate=None if not has_shared_gate else shared_expert_gate[begin_chunk_idx:end_chunk_idx], + run_config=( + None if not has_shared_gate else {"BLOCK_M": 1, "BLOCK_DIM": 128, "NUM_STAGE": 1, "num_warps": 2} + ), ) return out_hidden_states @@ -1035,6 +1061,8 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1082,8 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1075,6 +1105,8 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: pass @@ -1105,6 +1137,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return fused_experts_impl( hidden_states, @@ -1124,6 +1158,8 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1145,6 +1181,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return torch.empty_like(hidden_states) @@ -1176,6 +1214,8 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1235,8 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return hidden_states else: @@ -1215,4 +1257,6 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index e16351eec8..4f95cca7c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -14,12 +14,20 @@ def _moe_sum_reduce_kernel( output_ptr, output_stride_0, output_stride_1, + shared_ptr, + shared_stride_0, + shared_stride_1, + gate_ptr, + gate_stride_0, + gate_stride_1, token_num: int, topk_num: int, hidden_dim: int, BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, NUM_STAGE: tl.constexpr, + HAS_SHARED_GATE: tl.constexpr, + GATE_DIM: tl.constexpr, ): input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) @@ -42,12 +50,38 @@ def _moe_sum_reduce_kernel( for i in tl.range(0, topk_num, num_stages=NUM_STAGE): tmp = tl.load(input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0) accumulator += tmp + if HAS_SHARED_GATE: + shared = tl.load( + shared_ptr + token_index * shared_stride_0 + offs_dim * shared_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + if GATE_DIM == 1: + gate = tl.load(gate_ptr + token_index * gate_stride_0).to(tl.float32) + tl.zeros( + (BLOCK_DIM,), dtype=tl.float32 + ) + else: + gate = tl.load( + gate_ptr + token_index * gate_stride_0 + offs_dim * gate_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + gate = 1.0 / (1.0 + tl.exp(-gate)) + accumulator += shared * gate store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end) -def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): - return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)} +def _get_moe_sum_reduce_static_key( + input: torch.Tensor, output: torch.Tensor, shared: torch.Tensor = None, gate: torch.Tensor = None +): + return { + "topk_num": input.shape[1], + "hidden_dim": input.shape[2], + "out_dtype": str(output.dtype), + "has_shared_gate": shared is not None, + "gate_dim": 0 if gate is None else gate.shape[-1], + } def _get_moe_sum_reduce_configs(): @@ -67,12 +101,20 @@ def _get_moe_sum_reduce_configs(): run_key_func=lambda input: input.shape[0], mutates_args=["output"], ) -def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None): +def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate=None, run_config: Dict = None): assert input.is_contiguous() assert output.is_contiguous() token_num, topk_num, hidden_dim = input.shape assert output.shape[0] == token_num and output.shape[1] == hidden_dim + has_shared_gate = shared is not None + if has_shared_gate: + assert gate is not None + shared = shared.view(token_num, hidden_dim) + gate = gate.view(token_num, gate.shape[-1]) + assert shared.is_contiguous() + assert gate.is_contiguous() + assert gate.shape[1] in (1, hidden_dim) if not run_config: run_config = { @@ -97,12 +139,20 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = *input.stride(), output, *output.stride(), + shared if has_shared_gate else output, + shared.stride(0) if has_shared_gate else 0, + shared.stride(1) if has_shared_gate else 0, + gate if has_shared_gate else output, + gate.stride(0) if has_shared_gate else 0, + gate.stride(1) if has_shared_gate else 0, token_num=token_num, topk_num=topk_num, hidden_dim=hidden_dim, BLOCK_M=BLOCK_M, BLOCK_DIM=BLOCK_DIM, NUM_STAGE=NUM_STAGE, + HAS_SHARED_GATE=has_shared_gate, + GATE_DIM=gate.shape[1] if has_shared_gate else 0, num_warps=num_warps, ) return diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index 8dc8558922..79f57ba051 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -3,6 +3,7 @@ import triton import triton.language as tl import os +from lightllm.common.triton_utils.autotuner import autotune rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) @@ -48,6 +49,51 @@ def _rms_norm_fwd_fused( tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) +@triton.jit +def _add_rms_norm_fwd_fused( + X, + R, + Y, + W, + x_stride0, + x_stride1, + r_stride0, + r_stride1, + y_stride0, + y_stride1, + N, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + X += row * x_stride0 + R += row * r_stride0 + Y += row * y_stride0 + + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + x = x + r + tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask) + _var += x * x + + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + y = x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + y *= w + tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + + def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): # allocate output y = torch.empty_like(x) if out is None else out @@ -60,7 +106,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() + MAX_FUSED_SIZE = 65536 // x_arg.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) # print("BLOCK_SIZE:", BLOCK_SIZE) if N > BLOCK_SIZE: @@ -86,6 +132,77 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) return y +def _get_add_rmsnorm_configs(): + return [{"num_warps": nw} for nw in [4, 8, 16]] + + +def _get_add_rmsnorm_static_key(x_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor): + return { + "x_dtype": str(x_arg.dtype), + "out_dtype": str(y_arg.dtype), + "weight_dtype": "none" if weight is None else str(weight.dtype), + "N": x_arg.shape[1], + "has_weight": weight is not None, + } + + +@autotune( + kernel_name="add_rmsnorm_forward:v1", + configs_gen_func=_get_add_rmsnorm_configs, + static_key_func=_get_add_rmsnorm_static_key, + run_key_func=lambda x_arg: x_arg.shape[0], + mutates_args=["x_arg", "y_arg"], +) +def _add_rmsnorm_forward( + x_arg: torch.Tensor, + residual_arg: torch.Tensor, + y_arg: torch.Tensor, + weight: torch.Tensor, + eps: float, + run_config: dict = None, +): + M, N = x_arg.shape + MAX_FUSED_SIZE = 65536 // x_arg.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + if BLOCK_SIZE > 16384: + BLOCK_SIZE = 16384 + if not run_config: + run_config = {"num_warps": rmsnorm_num_warps} + _add_rms_norm_fwd_fused[(M,)]( + x_arg, + residual_arg, + y_arg, + weight, + x_arg.stride(0), + x_arg.stride(1), + residual_arg.stride(0), + residual_arg.stride(1), + y_arg.stride(0), + y_arg.stride(1), + N, + eps, + HAS_WEIGHT=weight is not None, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=run_config["num_warps"], + ) + return y_arg + + +def add_rmsnorm_forward(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float, out=None): + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, x.shape[-1]) + residual_arg = residual.view(-1, x.shape[-1]) + y_arg = y.view(-1, x.shape[-1]) + assert x_arg.shape == residual_arg.shape == y_arg.shape + if weight is not None: + assert x_arg.shape[-1] == weight.shape[0] + assert y.data_ptr() == y_arg.data_ptr() + _add_rmsnorm_forward(x_arg, residual_arg, y_arg, weight, eps) + return y + + def torch_rms_norm(x, weight, eps): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..57a1c4d0a3 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -34,10 +34,11 @@ def _fwd_kernel_repack_kv_index( @torch.no_grad() -def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): +def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index, zero_output: bool = True): batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() + # Some flashinfer callers need zero-filled padding outside the valid indptr range. + if zero_output: + out_kv_index.zero_() BLOCK = 64 grid = ( batch_size, diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f15badde25..31ff5b8b65 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -100,6 +100,24 @@ def all_reduce(self, input_: torch.Tensor) -> None: return return dist.all_reduce(input_, group=self.device_group) + def all_reduce_residual_rmsnorm( + self, + input_: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + alloc_func=torch.empty, + ): + if self.flashinfer_reduce is None: + return None + return self.flashinfer_reduce.all_reduce_residual_rmsnorm( + input_, + residual=residual, + norm_weight=norm_weight, + eps=eps, + alloc_func=alloc_func, + ) + def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, async_op: bool = False) -> None: return dist.all_gather_into_tensor(output_, input_, group=self.device_group, async_op=async_op) @@ -235,6 +253,27 @@ def all_reduce( return dist.all_reduce(input_, op, group, async_op) +def all_reduce_residual_rmsnorm( + input_: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None, + alloc_func=torch.empty, +): + if _is_single_group(group=group): + return None + if isinstance(group, CustomProcessGroup): + return group.all_reduce_residual_rmsnorm( + input_, + residual=residual, + norm_weight=norm_weight, + eps=eps, + alloc_func=alloc_func, + ) + return None + + def all_gather_into_tensor( output_: torch.Tensor, input_: torch.Tensor, diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac7..d469c3bc66 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -132,4 +132,39 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: input=inp, workspace=self._workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + launch_with_pdl=True, ) + + def all_reduce_residual_rmsnorm( + self, + inp: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + alloc_func=torch.empty, + ): + if ( + residual.shape != inp.shape + or residual.dtype != inp.dtype + or not residual.is_cuda + or norm_weight.dtype != inp.dtype + or norm_weight.shape[0] != inp.shape[-1] + ): + return None + if not self.should_use(inp): + return None + + residual_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + norm_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + flashinfer_comm.allreduce_fusion( + input=inp, + workspace=self._workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=True, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=norm_weight, + rms_eps=eps, + ) + return residual_out, norm_out diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..d9ac369960 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -28,14 +28,24 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..3492041813 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -10,8 +10,10 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import add_shared_expert_gate_, sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -114,15 +116,14 @@ def _compute_shared_expert( ): input = input.view(-1, self.embed_dim_) shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) - return shared_expert_out + gate = layer_weight.ffn_gate.mm(input) + return shared_expert_out, gate def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape @@ -135,15 +136,16 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + shared_expert_out=shared_expert_out, + shared_expert_gate=gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) - hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -158,7 +160,7 @@ def _moe_ffn_edp( is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) - ep_output.add_(shared_expert_out) + add_shared_expert_gate_(ep_output, shared_expert_out, gate) return ep_output def _get_qkv( @@ -169,13 +171,25 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -199,15 +213,24 @@ def _get_o( input, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, + ) -> torch.Tensor: + o_tensor = self._get_o_local(input=input, infer_state=infer_state, layer_weight=layer_weight) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + return o_tensor + + def _get_o_local( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) + sigmoid_mul_(input, infer_state.gate_value) infer_state.gate_value = None o_tensor = layer_weight.o_proj.mm(input) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor # ==================== GDN Helper Methods ==================== @@ -257,8 +280,9 @@ def gdn_forward( else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) - core_attn_out = self._gdn_decode_kernel( + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -406,6 +430,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -413,18 +438,24 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( mixed_qkv, + z, + a, + b, conv_states, layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + layer_weight.linear_conv1d.bias, + infer_state.b_buffer_idx, + self.activation, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, ) - - # Recurrent processing with fused gating - # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -438,4 +469,4 @@ def _gdn_decode_kernel( a_raw=a, b_raw=b, ) - return core_attn_out + return core_attn_out, z diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 0d415ca0e8..51b702039b 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -11,6 +11,83 @@ QKVROWNMMWeight, QKGEMMANormWeight, ) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + self.kv_repeat_times = 1 + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_tp_padded_head_num(self, head_num): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + if self.tp_world_size_ % head_num == 0: + self.kv_repeat_times = self.tp_world_size_ // head_num + return self.kv_repeat_times * head_num // self.tp_world_size_ + raise ValueError( + f"head_num must be divisible by tp_world_size_ or vice versa, found {head_num} % {self.tp_world_size_}" + ) class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -23,25 +100,39 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( - in_dim=in_dim, - q_head_num=self.q_head_num_, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) + qkv_quant = self.get_quant_method("qkv_proj") + gate_quant = self.get_quant_method("o_gate_proj") + if qkv_quant.method_name == "none" and gate_quant.method_name == "none": + self.qkvo_gate_proj = QKVGatedROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, + ) + else: + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=qkv_quant, + ) + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=gate_quant, + ) def _init_weight(self): if self.is_linear_attention_layer: diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 9b5e9b7a50..a95196abaf 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -17,6 +17,8 @@ from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.common.basemodel.batch_objs import ModelOutput +from lightllm.distributed import all_reduce, all_reduce_residual_rmsnorm logger = init_logger(__name__) @@ -51,6 +53,29 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch def autotune_layers(self): return self.config["full_attention_interval"] + def _autotune_extra_warmup(self): + if not self.trans_layers_weight: + return + + norm_weight = self.trans_layers_weight[0].ffn_norm_weight_ + add_rmsnorm = getattr(norm_weight, "add_rmsnorm", None) + if add_rmsnorm is None: + return + + hidden_dim = norm_weight.weight.shape[0] + max_batch_size = min(self.graph_max_batch_size, self.batch_max_tokens) + warmup_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + warmup_batch_sizes = [bs for bs in warmup_batch_sizes if bs <= max_batch_size] + if max_batch_size not in warmup_batch_sizes: + warmup_batch_sizes.append(max_batch_size) + + for batch_size in sorted(set(warmup_batch_sizes)): + x = torch.zeros((batch_size, hidden_dim), dtype=self.data_type, device="cuda") + residual = torch.zeros_like(x) + out = torch.empty_like(x) + add_rmsnorm(input=x, residual=residual, eps=self.layers_infer[0].eps_, out=out) + return + def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -102,3 +127,80 @@ def _init_req_manager(self): self.max_req_num, create_max_seq_len, None, linear_config=LinearAttCacheConfig.load_from_args() ) return + + def _token_forward(self, infer_state: Qwen3NextInferStateInfo): + input_ids = infer_state.input_ids + input_embs = self.pre_infer.token_forward(input_ids, infer_state, self.pre_post_weight) + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) + + next_att_normed = None + for i in range(self.layers_num): + layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i] + layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i] + + if next_att_normed is None: + input1 = layer._att_norm(input_embs, infer_state, layer_weight) + else: + input1 = next_att_normed + next_att_normed = None + + if layer.is_linear_attention_layer: + o = layer.token_attention_forward(input1, infer_state, layer_weight) + input1 = layer._add_residual_ffn_norm(input_embs, o, infer_state, layer_weight) + o = None + else: + q, cache_kv = layer._get_qkv(input1, infer_state, layer_weight) + layer._post_cache_kv(cache_kv, infer_state, layer_weight) + o = layer._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = layer._get_o_local(o, infer_state, layer_weight) + fused = None + if layer.tp_world_size_ > 1: + fused = all_reduce_residual_rmsnorm( + o, + residual=input_embs.view(-1, layer.embed_dim_), + norm_weight=layer_weight.ffn_norm_weight_.weight, + eps=layer.eps_, + group=infer_state.dist_group, + alloc_func=layer.alloc_tensor, + ) + if fused is None: + if layer.tp_world_size_ > 1: + all_reduce(o, group=infer_state.dist_group) + input1 = layer._add_residual_ffn_norm(input_embs, o, infer_state, layer_weight) + else: + input_embs, input1 = fused + o = None + + ffn_out = layer._ffn(input1, infer_state, layer_weight) + ffn_out = ffn_out.view(-1, layer.embed_dim_) + + if i + 1 < self.layers_num: + next_layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i + 1] + next_layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i + 1] + add_rmsnorm = getattr(next_layer_weight.att_norm_weight_, "add_rmsnorm", None) + if add_rmsnorm is not None: + next_att_normed = add_rmsnorm( + input=input_embs, + residual=ffn_out, + eps=next_layer.eps_, + alloc_func=next_layer.alloc_tensor, + ) + continue + + input_embs.add_(ffn_out) + + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + predict_logits: torch.Tensor = self.post_infer.token_forward( + last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight + ) + + model_output = ModelOutput(logits=predict_logits.contiguous()) + if self.is_mtp_mode: + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + model_output.mtp_main_output_hiddens = input_embs.contiguous() + + if infer_state.is_cuda_graph: + model_output.to_no_ref_tensor() + + return model_output diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..a025e35c64 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,284 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + BLOCK_QKV: tl.constexpr, + BLOCK_GATE: tl.constexpr, +): + row = tl.program_id(0) + qkv_offsets = tl.arange(0, BLOCK_QKV) + + q_mask = qkv_offsets < q_dim + q_vals = tl.load(mixed_qkv + row * stride_m_b + qkv_offsets * stride_m_d, mask=q_mask, other=0.0) + tl.store(q_out + row * q_dim + qkv_offsets, q_vals, mask=q_mask) + + k_mask = qkv_offsets < k_dim + k_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + qkv_offsets) * stride_m_d, + mask=k_mask, + other=0.0, + ) + tl.store(k_out + row * k_dim + qkv_offsets, k_vals, mask=k_mask) + + v_mask = qkv_offsets < v_dim + v_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + k_dim + qkv_offsets) * stride_m_d, + mask=v_mask, + other=0.0, + ) + tl.store(v_out + row * v_dim + qkv_offsets, v_vals, mask=v_mask) + + z_vals = tl.load(z_raw + row * stride_z_b + qkv_offsets, mask=v_mask, other=0.0) + tl.store(z_out + row * v_dim + qkv_offsets, z_vals, mask=v_mask) + + gate_offsets = tl.arange(0, BLOCK_GATE) + gate_mask = gate_offsets < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + gate_offsets * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + gate_offsets * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + gate_offsets, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + gate_offsets, b_vals, mask=gate_mask) + + +@torch.no_grad() +def pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim)) + block_gate = triton.next_power_of_2(gate_dim) + _pack_gdn_decode_kernel[(batch,)]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + BLOCK_QKV=block_qkv, + BLOCK_GATE=block_gate, + num_warps=4, + ) + return q, k, v, z, a, b + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + s0 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s1 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s2 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w0 = tl.load(conv_weight + offs * stride_w_d + 0 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w1 = tl.load(conv_weight + offs * stride_w_d + 1 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w2 = tl.load(conv_weight + offs * stride_w_d + 2 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w3 = tl.load(conv_weight + offs * stride_w_d + 3 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y = s0 * w0 + s1 * w1 + s2 * w2 + x * w3 + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, s1, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, s2, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..c2b110def6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,108 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _add_shared_expert_gate_kernel( + hidden, + shared, + gate, + stride_h_m: tl.constexpr, + stride_h_n: tl.constexpr, + stride_s_m: tl.constexpr, + stride_s_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + + hidden_ptrs = hidden + row * stride_h_m + offs * stride_h_n + shared_vals = tl.load(shared + row * stride_s_m + offs * stride_s_n, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + out = hidden_vals + shared_vals * gate_vals + tl.store(hidden_ptrs, out.to(hidden.dtype.element_ty), mask=mask) + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + hidden_arg = hidden.view(-1, hidden.shape[-1]) + shared_arg = shared.view(-1, hidden.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert hidden_arg.shape == shared_arg.shape + assert gate_arg.shape[0] == hidden_arg.shape[0] and gate_arg.shape[1] in (1, hidden_arg.shape[1]) + _, n = hidden_arg.shape + block_n = triton.next_power_of_2(n) + _add_shared_expert_gate_kernel[(hidden_arg.shape[0],)]( + hidden_arg, + shared_arg, + gate_arg, + hidden_arg.stride(0), + hidden_arg.stride(1), + shared_arg.stride(0), + shared_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return hidden + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x_arg, + gate_arg, + x_arg.stride(0), + x_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return x diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7e40421140..8e9866a1d7 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -401,7 +401,7 @@ def make_argument_parser() -> argparse.ArgumentParser: default=["auto"], help="""decode attention kernel used in llm. auto: automatically select best backend based on GPU and available packages - (priority: flashinfer > fa3 > triton)""", + (priority: fa3 > flashinfer > triton)""", ) parser.add_argument( "--vit_att_backend", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 0d934c44c9..df79913e23 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -314,6 +314,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "n": request.n, "best_of": request.n, "add_special_tokens": False, + "return_logprobs": request.logprobs is not None, "seed": request.seed, } @@ -822,6 +823,7 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) -> "n": request.n, "best_of": request.best_of, "add_special_tokens": False, + "return_logprobs": request.logprobs is not None, "seed": request.seed, } if request.max_completion_tokens is not None: diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c39559f5f6..3515fbf1a1 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -304,6 +304,7 @@ class SamplingParams(ctypes.Structure): ), # whether to add spaces between special tokens when decoding ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache + ("return_logprobs", ctypes.c_bool), # whether generated token logprobs are required by the caller ("seed", ctypes.c_int64), # random seed ] @@ -340,6 +341,7 @@ def init(self, tokenizer, **kwargs): self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) self.print_eos_token = kwargs.get("print_eos_token", False) + self.return_logprobs = kwargs.get("return_logprobs", True) self.seed = kwargs.get("seed", -1) self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() @@ -486,6 +488,7 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, + "return_logprobs": self.return_logprobs, "seed": self.seed, } diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4323a62d1c..512b6ad2c2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -343,7 +343,11 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return - def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): + def _async_copy_next_token_infos_to_pin_mem( + self, + next_token_ids: torch.Tensor, + next_token_logprobs: Optional[torch.Tensor], + ): """ 这个函数会把next token id和logprobs保存到pinned memory中 这样可以保障post_handle 函数可以读取到正常的输出结果。 @@ -352,9 +356,13 @@ def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, key="next_token_ids", gpu_tensor=next_token_ids, ) - next_token_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="next_token_logprobs", - gpu_tensor=next_token_logprobs, + next_token_logprobs_cpu = ( + None + if next_token_logprobs is None + else g_pin_mem_manager.async_copy_from_gpu_tensor( + key="next_token_logprobs", + gpu_tensor=next_token_logprobs, + ) ) return next_token_ids_cpu, next_token_logprobs_cpu @@ -700,7 +708,7 @@ def _post_handle( self, run_reqs: List[InferReq], next_token_ids: List[int], - next_token_logprobs: List[float], + next_token_logprobs: Optional[List[float]], run_reqs_update_packs: List[InferReqUpdatePack], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, @@ -709,9 +717,18 @@ def _post_handle( extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。 """ - for req_obj, next_token_id, next_token_logprob, pack in zip( - run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs - ): + if next_token_logprobs is None: + iter_items = zip(run_reqs, next_token_ids, run_reqs_update_packs) + else: + iter_items = zip(run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs) + + for item in iter_items: + if next_token_logprobs is None: + req_obj, next_token_id, pack = item + next_token_logprob = 0.0 + else: + req_obj, next_token_id, next_token_logprob, pack = item + req_obj: InferReq = req_obj pack: InferReqUpdatePack = pack pack.handle( @@ -800,10 +817,38 @@ def _sample_and_scatter_token( mask=b_has_out, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( - next_token_ids, next_token_logprobs + next_token_ids, + next_token_logprobs, ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _can_decode_pre_post_before_prev_post_handle( + self, + run_reqs: List[InferReq], + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + ) -> bool: + if not self.support_overlap: + return False + if extra_post_req_handle_func is not None or self.decode_mask_func is not None: + return False + if self.args.mtp_mode: + return False + + for req_obj in run_reqs: + if req_obj.mtp_step != 0: + return False + if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + return False + + shm_param = req_obj.sampling_param.shm_param + if not shm_param.ignore_eos: + return False + if len(req_obj.stop_sequences) != 0: + return False + if req_obj.cur_output_len + 1 >= shm_param.max_new_tokens: + return False + return True + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 792a10a788..231a98f853 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -163,11 +163,22 @@ def decode_normal( sync_event.record() # 第二阶段 - event_pack.notify_post_handle_and_wait_pre_post_handle() - update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) + can_pre_post_early = self._can_decode_pre_post_before_prev_post_handle( + run_reqs=run_reqs, + extra_post_req_handle_func=self.extra_post_req_handle_func, + ) + if can_pre_post_early: + event_pack.notify_post_handle_event.set() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) + event_pack.notify_forward_event.set() + event_pack.wait_pre_post_handle_event.wait() + else: + event_pack.notify_post_handle_and_wait_pre_post_handle() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) # 第三阶段 - event_pack.notify_forward_and_wait_post_handle() + if not can_pre_post_early: + event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() self._post_handle( run_reqs=run_reqs, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 5b29ea0510..14fd0c21c6 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,5 +1,5 @@ import torch -from typing import List, Tuple +from typing import List, Optional, Tuple, Union from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids @@ -7,8 +7,24 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args +_flashinfer_top_k_top_p_sampling_from_logits = None +_flashinfer_top_k_top_p_sampling_from_logits_checked = False +_flashinfer_top_k_top_p_sampling_from_probs = None +_flashinfer_top_k_top_p_sampling_from_probs_checked = False +_flashinfer_top_p_sampling_from_probs = None +_flashinfer_top_p_sampling_from_probs_checked = False +_flashinfer_top_k_sampling_from_probs = None +_flashinfer_top_k_sampling_from_probs_checked = False +_uniform_tensor_cache = {} +_softmax_out_cache = {} +_is_flashinfer_sampling_backend = None + def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): + fast_next_token_ids = _try_flashinfer_sample_without_penalty(logits, reqs) + if fast_next_token_ids is not None: + return fast_next_token_ids.view(-1), None + ( b_req_idx, b_temperatures, @@ -23,6 +39,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): skip_top_k, skip_top_p, exist_req_use_random_seed, + need_logprobs, ) = _get_post_sample_tensors(reqs) eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) @@ -75,7 +92,18 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): cu_invalid_token_num=cu_invalid_token_num, ) - logits.div_(b_temperatures.view((-1, 1))) + if b_temperatures is not None: + logits.div_(b_temperatures.view((-1, 1))) + + if is_all_greedy and not need_logprobs: + batch_next_token_ids = torch.argmax(logits, -1) + if get_env_start_args().mtp_mode: + batch_next_token_logprobs = torch.zeros( + batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device + ) + return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) + return batch_next_token_ids.view(-1), None + probs = torch.softmax(logits, dim=-1) if is_all_greedy: @@ -86,16 +114,82 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): elif skip_top_k and skip_top_p: # topk 等于整个词表,topp 等于1.0,等价于不进行topk topp过滤,直接进行随机采样,可以提升采样速度 batch_next_token_ids = _random_sample(probs, reqs, exist_req_use_random_seed) + if not need_logprobs: + return batch_next_token_ids.view(-1), None batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) else: batch_next_token_ids, batch_next_token_logprobs = _top_p_top_k_sample( - reqs, probs, b_top_ps, b_top_ks, exist_req_use_random_seed + reqs, + probs, + b_top_ps, + b_top_ks, + skip_top_k, + skip_top_p, + exist_req_use_random_seed, + need_logprobs, ) + if batch_next_token_logprobs is None: + return batch_next_token_ids.view(-1), None return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) +def _try_flashinfer_sample_without_penalty(logits: torch.Tensor, reqs: List[InferReq]) -> Optional[torch.Tensor]: + if not _is_flashinfer_sampling() or not reqs: + return None + + first_param = reqs[0].sampling_param.shm_param + top_p = first_param.top_p + top_k = first_param.top_k + temperature = first_param.temperature + vocab_size = reqs[0].vocab_size + + if top_k <= 1 or (top_k == vocab_size and top_p == 1.0): + return None + + for req in reqs: + shm_param = req.sampling_param.shm_param + if shm_param.return_logprobs: + return None + if req.generator is not None: + return None + if len(req.sampling_param.invalid_token_ids) != 0: + return None + if not shm_param.ignore_eos: + return None + if shm_param.presence_penalty != 0.0: + return None + if shm_param.frequency_penalty != 0.0: + return None + if shm_param.repetition_penalty != 1.0: + return None + if shm_param.temperature != temperature: + return None + if shm_param.top_p != top_p: + return None + if shm_param.top_k != top_k: + return None + + if temperature != 1.0: + logits.div_(temperature) + + if top_k == vocab_size and top_p != 1.0: + top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) + return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor) + + top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) + top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device) + return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor) + + +def _is_flashinfer_sampling() -> bool: + global _is_flashinfer_sampling_backend + if _is_flashinfer_sampling_backend is None: + _is_flashinfer_sampling_backend = get_env_start_args().sampling_backend == "flashinfer" + return _is_flashinfer_sampling_backend + + def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): probs_sort, probs_idx = probs.sort(dim=-1, descending=True) @@ -107,13 +201,123 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor return probs_sort, probs_idx +def _flashinfer_top_p_top_k_sample_from_logits( + logits: torch.Tensor, + b_top_ps: Union[torch.Tensor, float], + b_top_ks: Union[torch.Tensor, int], +) -> Optional[torch.Tensor]: + global _flashinfer_top_k_top_p_sampling_from_logits + global _flashinfer_top_k_top_p_sampling_from_logits_checked + + if not _flashinfer_top_k_top_p_sampling_from_logits_checked: + try: + from flashinfer.sampling import top_k_top_p_sampling_from_logits + except ImportError: + top_k_top_p_sampling_from_logits = None + _flashinfer_top_k_top_p_sampling_from_logits = top_k_top_p_sampling_from_logits + _flashinfer_top_k_top_p_sampling_from_logits_checked = True + + if _flashinfer_top_k_top_p_sampling_from_logits is None: + return None + + return _flashinfer_top_k_top_p_sampling_from_logits( + logits, + b_top_ks, + b_top_ps, + filter_apply_order="joint", + deterministic=True, + check_nan=False, + ) + + +def _flashinfer_top_p_sample_from_logits( + logits: torch.Tensor, top_p: Union[torch.Tensor, float] +) -> Optional[torch.Tensor]: + probs = _softmax_out(logits) + return _flashinfer_top_p_sample_from_probs(probs, top_p) + + +def _get_uniform_tensor(value: Union[float, int], size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + key = (str(device), dtype, size, value) + tensor = _uniform_tensor_cache.get(key) + if tensor is None: + tensor = torch.full((size,), value, dtype=dtype, device=device) + _uniform_tensor_cache[key] = tensor + return tensor + + +def _softmax_out(logits: torch.Tensor) -> torch.Tensor: + key = (str(logits.device), logits.dtype, tuple(logits.shape)) + probs = _softmax_out_cache.get(key) + if probs is None: + probs = torch.empty_like(logits) + _softmax_out_cache[key] = probs + torch.ops.aten._softmax.out(logits, -1, False, out=probs) + return probs + + +def _get_flashinfer_top_k_top_p_sampling_from_probs(): + global _flashinfer_top_k_top_p_sampling_from_probs + global _flashinfer_top_k_top_p_sampling_from_probs_checked + + if not _flashinfer_top_k_top_p_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_k_top_p_sampling_from_probs + except ImportError: + top_k_top_p_sampling_from_probs = None + _flashinfer_top_k_top_p_sampling_from_probs = top_k_top_p_sampling_from_probs + _flashinfer_top_k_top_p_sampling_from_probs_checked = True + return _flashinfer_top_k_top_p_sampling_from_probs + + +def _flashinfer_top_p_sample_from_probs( + probs: torch.Tensor, top_p: Union[torch.Tensor, float] +) -> Optional[torch.Tensor]: + global _flashinfer_top_p_sampling_from_probs + global _flashinfer_top_p_sampling_from_probs_checked + + if not _flashinfer_top_p_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_p_sampling_from_probs + except ImportError: + top_p_sampling_from_probs = None + _flashinfer_top_p_sampling_from_probs = top_p_sampling_from_probs + _flashinfer_top_p_sampling_from_probs_checked = True + + if _flashinfer_top_p_sampling_from_probs is None: + return None + + return _flashinfer_top_p_sampling_from_probs(probs, top_p, deterministic=True, check_nan=False) + + +def _flashinfer_top_k_sample_from_probs(probs: torch.Tensor, top_k: Union[torch.Tensor, int]) -> Optional[torch.Tensor]: + global _flashinfer_top_k_sampling_from_probs + global _flashinfer_top_k_sampling_from_probs_checked + + if not _flashinfer_top_k_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_k_sampling_from_probs + except ImportError: + top_k_sampling_from_probs = None + _flashinfer_top_k_sampling_from_probs = top_k_sampling_from_probs + _flashinfer_top_k_sampling_from_probs_checked = True + + if _flashinfer_top_k_sampling_from_probs is None: + return None + + return _flashinfer_top_k_sampling_from_probs(probs, top_k, deterministic=True, check_nan=False) + + def _top_p_top_k_sample( reqs: List[InferReq], probs: torch.Tensor, - b_top_ps: torch.Tensor, - b_top_ks: torch.Tensor, + b_top_ps: Union[torch.Tensor, float], + b_top_ks: Union[torch.Tensor, int], + skip_top_k: bool, + skip_top_p: bool, exist_req_use_random_seed: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: + need_logprobs: bool, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: sampling_backend = get_env_start_args().sampling_backend if sampling_backend == "triton": @@ -123,19 +327,32 @@ def _top_p_top_k_sample( else: sampled_index = _random_sample(probs_sort, reqs, exist_req_use_random_seed).view(-1, 1) next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index) + if not need_logprobs: + return next_token_ids.view(-1), None next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index)) return next_token_ids.view(-1), next_token_logprobs.view(-1) elif sampling_backend == "flashinfer": - from flashinfer.sampling import top_k_top_p_sampling_from_probs - - batch_next_token_ids = top_k_top_p_sampling_from_probs( - probs, - b_top_ks, - b_top_ps, - filter_apply_order="joint", - check_nan=False, - ) + if skip_top_k: + batch_next_token_ids = _flashinfer_top_p_sample_from_probs(probs, b_top_ps) + elif skip_top_p: + batch_next_token_ids = _flashinfer_top_k_sample_from_probs(probs, b_top_ks) + else: + top_k_top_p_sampling_from_probs = _get_flashinfer_top_k_top_p_sampling_from_probs() + if top_k_top_p_sampling_from_probs is None: + raise ImportError("flashinfer.sampling.top_k_top_p_sampling_from_probs is not available") + batch_next_token_ids = top_k_top_p_sampling_from_probs( + probs, + b_top_ks, + b_top_ps, + filter_apply_order="joint", + deterministic=True, + check_nan=False, + ) + if batch_next_token_ids is None: + raise ImportError("flashinfer sampling op is not available") + if not need_logprobs: + return batch_next_token_ids.view(-1), None int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64) int64_batch_next_token_ids[:] = batch_next_token_ids batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1)) @@ -165,6 +382,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k = True skip_top_p = True exist_req_use_random_seed = False + need_logprobs = False + all_temperature_one = True # invalid token ids invalid_token_ids: List[int] = [] @@ -192,6 +411,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_p = False if req_obj.generator is not None: exist_req_use_random_seed = True + if shm_param.return_logprobs: + need_logprobs = True + if shm_param.temperature != 1.0: + all_temperature_one = False req_idxes.append(req_obj.req_idx) invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) cu_invalid_token_num.append(invalid_token_num_start) @@ -200,13 +423,25 @@ def _get_post_sample_tensors(reqs: List[InferReq]): invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) - temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) - top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) - top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) length_penalty_param_cpu = g_pin_mem_manager.gen_from_list( key="length_penalty_param", data=length_penalty_param, dtype=torch.int32 ) mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) + temperatures_cpu = ( + None + if all_temperature_one + else g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) + ) + sampling_backend = get_env_start_args().sampling_backend + need_top_k_top_p_tensors = (not is_all_greedy) and (not (skip_top_k and skip_top_p)) + need_top_ps_tensor = need_top_k_top_p_tensors and (sampling_backend != "flashinfer" or not skip_top_p) + need_top_ks_tensor = need_top_k_top_p_tensors and (sampling_backend != "flashinfer" or not skip_top_k) + top_ps_cpu = ( + g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) if need_top_ps_tensor else None + ) + top_ks_cpu = ( + g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) if need_top_ks_tensor else None + ) if has_invalid_token_ids: invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list( @@ -218,9 +453,9 @@ def _get_post_sample_tensors(reqs: List[InferReq]): return ( req_idxes_cpu.cuda(non_blocking=True), - temperatures_cpu.cuda(non_blocking=True), - top_ps_cpu.cuda(non_blocking=True), - top_ks_cpu.cuda(non_blocking=True), + temperatures_cpu.cuda(non_blocking=True) if temperatures_cpu is not None else None, + top_ps_cpu.cuda(non_blocking=True) if top_ps_cpu is not None else None, + top_ks_cpu.cuda(non_blocking=True) if top_ks_cpu is not None else None, length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, @@ -230,4 +465,5 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k, skip_top_p, exist_req_use_random_seed, + need_logprobs, ) diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py index b48a62506d..b79a554f48 100644 --- a/lightllm/utils/sgl_utils.py +++ b/lightllm/utils/sgl_utils.py @@ -17,14 +17,16 @@ ) try: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, get_scheduler_metadata flash_attn_varlen_func = flash_attn_varlen_func flash_attn_with_kvcache = flash_attn_with_kvcache + get_scheduler_metadata = get_scheduler_metadata merge_state_v2 = sgl_ops.merge_state_v2 except: flash_attn_varlen_func = None flash_attn_with_kvcache = None + get_scheduler_metadata = None merge_state_v2 = None logger.warning( "sgl_kernel is not installed, or the installed version did not support fa3. \