diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..26f2b338b7 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -134,6 +134,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -150,6 +151,7 @@ def experts( num_expert_group=num_expert_group, is_prefill=is_prefill, per_expert_scale=self.per_expert_scale, + shared_expert_gate=shared_expert_gate, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index dd6f9a6880..9db303f95e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -63,5 +63,6 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..5c76f18623 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,7 +76,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): + assert shared_expert_gate is None, "fused shared expert as MoE is not supported by DeepGEMM fused MoE" output = fused_experts( hidden_states=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..c4d1681408 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,7 +30,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): + assert shared_expert_gate is None, "fused shared expert as MoE is not supported by Marlin fused MoE" w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..20e9c0b113 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -62,27 +62,6 @@ def _select_experts( topk_weights.mul_(self.routed_scaling_factor) if per_expert_scale is not None: topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts, - end=self.n_routed_experts + self.num_fused_shared_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) return topk_weights, topk_ids def _fused_experts( @@ -94,11 +73,18 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + shared_expert_gate: Optional[torch.Tensor] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale use_fp8_w8a8 = w13_weight.dtype == torch.float8_e4m3fn + if shared_expert_gate is not None: + assert ( + type(self) is FuseMoeTriton + ), "fused shared expert as MoE is only supported by the Triton fused MoE implementation" + assert self.num_fused_shared_experts > 0, "shared_expert_gate requires fused shared experts" + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_experts fused_experts( @@ -111,6 +97,8 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + shared_expert_id=self.n_routed_experts if self.num_fused_shared_experts > 0 else -1, + shared_expert_gate=shared_expert_gate, ) return input_tensor @@ -129,6 +117,7 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -150,5 +139,6 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + shared_expert_gate=shared_expert_gate, ) return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index fb50398368..6f3f9fe62c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -53,13 +53,8 @@ def __init__( ) -> None: self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() - self.repeat_times = 1 - assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( - f"kv_head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by kv_head_num, " - f"but found: {kv_head_num} % {self.tp_world_size_}" - ) - kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + self.repeat_times = self._get_repeat_times(kv_head_num) + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.repeat_times) * head_dim out_dims = [kv_hidden_size, kv_hidden_size] super().__init__( in_dim=in_dim, @@ -78,18 +73,19 @@ def __init__( repeat_times=self.repeat_times, ) - def _get_tp_padded_head_num(self, head_num: int): - if head_num % self.tp_world_size_ == 0: - return head_num // self.tp_world_size_ - elif self.tp_world_size_ % head_num == 0: - self.repeat_times = self.tp_world_size_ // head_num - return self.repeat_times * head_num // self.tp_world_size_ + def _get_repeat_times(self, kv_head_num: int) -> int: + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" + ) + if kv_head_num % self.tp_world_size_ == 0: + return 1 else: - raise ValueError( - f"head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by head_num, " - f"but found: {head_num} % {self.tp_world_size_}" - ) + return self.tp_world_size_ // kv_head_num + + def _get_tp_padded_head_num(self, head_num: int, repeat_times: int) -> int: + return repeat_times * head_num // self.tp_world_size_ class QKVROWNMMWeight(MMWeightTpl): @@ -109,17 +105,12 @@ def __init__( self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() self.q_repeat_times = 1 - self.kv_repeat_times = 1 + self.kv_repeat_times = self._get_kv_repeat_times(kv_head_num) assert q_head_num % self.tp_world_size_ == 0, ( f"q_head_num must be divisible by tp_world_size_, " f"but found: {q_head_num} % {self.tp_world_size_}" ) - assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( - f"kv_head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by kv_head_num, " - f"but found: {kv_head_num} % {self.tp_world_size_}" - ) q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim - kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.kv_repeat_times) * head_dim out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size] super().__init__( in_dim=in_dim, @@ -157,18 +148,19 @@ def _get_param_slicer(self, sub_child_index: int): else: return self.kv_param_slicer - def _get_tp_padded_head_num(self, head_num: int): - if head_num % self.tp_world_size_ == 0: - return head_num // self.tp_world_size_ - elif self.tp_world_size_ % head_num == 0: - self.kv_repeat_times = self.tp_world_size_ // head_num - return self.kv_repeat_times * head_num // self.tp_world_size_ + def _get_kv_repeat_times(self, kv_head_num: int) -> int: + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" + ) + if kv_head_num % self.tp_world_size_ == 0: + return 1 else: - raise ValueError( - f"head_num must be divisible by tp_world_size_ or " - f"tp_world_size_ must be divisible by head_num, " - f"but found: {head_num} % {self.tp_world_size_}" - ) + return self.tp_world_size_ // kv_head_num + + def _get_tp_padded_head_num(self, head_num: int, repeat_times: int) -> int: + return repeat_times * head_num // self.tp_world_size_ class ROWBMMWeight(BMMWeightTpl): diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..4d7acea482 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -217,22 +217,51 @@ def moe_align1( def moe_align_fused_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] + shared_expert_gate_ptr, # [token_num, 1] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + routed_topk_num: tl.constexpr, + expert_num: tl.constexpr, topk_num: tl.constexpr, + shared_expert_id: tl.constexpr, BLOCK_SIZE: tl.constexpr, + ZERO_EXPERT_TOKEN_NUM: tl.constexpr, + BLOCK_EXPERT: tl.constexpr, + HAS_SHARED_EXPERT_GATE: tl.constexpr, ): token_block = tl.program_id(0) + if ZERO_EXPERT_TOKEN_NUM: + expert_offs = tl.arange(0, BLOCK_EXPERT) + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + tl.debug_barrier() + if shared_expert_id >= 0: + tl.store(expert_token_num_ptr + shared_expert_id, token_num, mask=token_block == 0) + offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num - expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0) - weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0) + token_ids = offs // topk_num + topk_offsets = offs - token_ids * topk_num + routed_offsets = token_ids * routed_topk_num + topk_offsets + is_shared_expert = topk_offsets >= routed_topk_num + + expert_ids = tl.load(topk_ids_ptr + routed_offsets, mask=mask & (is_shared_expert == 0), other=0) + expert_ids = tl.where(is_shared_expert, shared_expert_id, expert_ids) + weights = tl.load(topk_weights_ptr + routed_offsets, mask=mask & (is_shared_expert == 0), other=0.0) + if HAS_SHARED_EXPERT_GATE: + shared_weights = tl.load(shared_expert_gate_ptr + token_ids, mask=mask & is_shared_expert, other=0.0).to( + tl.float32 + ) + shared_weights = tl.sigmoid(shared_weights) + else: + shared_weights = tl.full((BLOCK_SIZE,), 1.0, dtype=tl.float32) + weights = tl.where(is_shared_expert, shared_weights, weights) - # 用 atomic_add 给 expert 分配写位置 - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask) + # Shared expert appears exactly once per token, so its position is deterministic. + routed_write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask & (is_shared_expert == 0)) + write_pos = tl.where(is_shared_expert, token_ids, routed_write_pos) # 按 token 顺序写 index 和 weight tl.store( @@ -249,8 +278,11 @@ def moe_align_fused_kernel( def _get_moe_align_fused_static_key( topk_weights: torch.Tensor, + shared_expert_id: int = -1, ) -> dict: topk_num = topk_weights.shape[1] + if shared_expert_id >= 0: + topk_num += 1 return { "topk_num": topk_num, } @@ -275,24 +307,43 @@ def _get_moe_align_fused_configs(): mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + shared_expert_id: int = -1, + shared_expert_gate: Optional[torch.Tensor] = None, + run_config: Optional[dict] = None, ): - token_num, topk_num = topk_ids.shape + token_num, routed_topk_num = topk_ids.shape + topk_num = routed_topk_num + (1 if shared_expert_id >= 0 else 0) if run_config is None: run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) + expert_num = expert_token_num.shape[0] + zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE + if shared_expert_gate is not None: + shared_expert_gate = shared_expert_gate.view(token_num, 1) grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( topk_ids, topk_weights, + shared_expert_gate if shared_expert_gate is not None else topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, token_num, + routed_topk_num, + expert_num, topk_num, + shared_expert_id, BLOCK_SIZE=BLOCK_SIZE, + ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, + BLOCK_EXPERT=triton.next_power_of_2(expert_num), + HAS_SHARED_EXPERT_GATE=shared_expert_gate is not None, num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -911,6 +962,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + shared_expert_id: int = -1, + shared_expert_gate: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -922,7 +975,8 @@ def fused_experts_impl( num_tokens, _ = hidden_states.shape E, N, _ = w1.shape CHUNK_SIZE = FFN_MOE_CHUNK_SIZE - topk_num = topk_ids.shape[1] + routed_topk_num = topk_ids.shape[1] + topk_num = routed_topk_num + (1 if shared_expert_id >= 0 else 0) M = min(num_tokens, CHUNK_SIZE) intermediate_cache13_shared = alloc_tensor_func( @@ -954,20 +1008,30 @@ def fused_experts_impl( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_shared_expert_gate = ( + shared_expert_gate[begin_chunk_idx:end_chunk_idx] if shared_expert_gate is not None else None + ) expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_token_count_in_align_kernel = topk_num * tokens_in_chunk <= 128 + expert_to_token_num = ( + torch.empty((E,), dtype=torch.int32, device="cuda") + if expert_token_count_in_align_kernel + else torch.zeros((E,), dtype=torch.int32, device="cuda") + ) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, expert_token_num=expert_to_token_num, topk_ids=curr_topk_ids, topk_weights=curr_topk_weights, + shared_expert_id=shared_expert_id, + shared_expert_gate=curr_shared_expert_gate, ) reused_mblock_infos = grouped_matmul( - curr_topk_ids.numel(), + tokens_in_chunk * topk_num, curr_hidden_states, a1_scale, expert_to_token_num, @@ -993,7 +1057,7 @@ def fused_experts_impl( ) grouped_matmul( - curr_topk_ids.numel(), + tokens_in_chunk * topk_num, intermediate_cache2.view(-1, N // 2), a2_scale, expert_to_token_num, @@ -1012,7 +1076,8 @@ def fused_experts_impl( ) moe_sum_reduce( - intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx] + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], ) return out_hidden_states @@ -1035,6 +1100,7 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_id: int = -1, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1120,7 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) @@ -1075,6 +1142,7 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_id: int = -1, ) -> None: pass @@ -1105,7 +1173,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, -) -> None: + shared_expert_id: int = -1, +) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, @@ -1124,6 +1193,7 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) @@ -1145,7 +1215,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, -) -> None: + shared_expert_id: int = -1, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1176,7 +1247,32 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_id: int = -1, + shared_expert_gate: Optional[torch.Tensor] = None, ): + if shared_expert_gate is not None: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + use_fp8_w8a8, + use_int8_w8a16, + w1_bias, + w2_bias, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + layout=layout, + alpha=alpha, + limit=limit, + shared_expert_id=shared_expert_id, + shared_expert_gate=shared_expert_gate, + ) + if inplace: torch.ops.lightllm.inplace_fused_experts_impl( hidden_states, @@ -1195,6 +1291,7 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) return hidden_states else: @@ -1215,4 +1312,5 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index 45c7ea73c6..a63d92692e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -122,7 +122,7 @@ def silu_and_mul_fwd( alpha=None, run_config=None, ): - assert input.is_contiguous() + assert input.stride(-1) == 1 assert output.is_contiguous() assert (limit is None and alpha is None) or (limit is not None and alpha is not None) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index e16351eec8..28221344b8 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -47,7 +47,11 @@ def _moe_sum_reduce_kernel( def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): - return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)} + return { + "topk_num": input.shape[1], + "hidden_dim": input.shape[2], + "out_dtype": str(output.dtype), + } def _get_moe_sum_reduce_configs(): diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..cff020ea40 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -37,9 +37,9 @@ def _parse_config(self): self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] self.num_fused_shared_experts = 0 - if get_env_start_args().enable_fused_shared_experts and self.is_moe: - # enable_fused_shared_experts can only work with tensor parallelism - assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." + start_args = get_env_start_args() + if start_args.enable_fused_shared_experts and not start_args.enable_ep_moe and self.is_moe: + # fused shared experts can only work with tensor parallelism self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) self.n_embed = self.network_config_["hidden_size"] self.n_inter = self.network_config_["intermediate_size"] diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..649db03b11 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -28,14 +28,19 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkv_gate_out = layer_weight.qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_logics_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index e4d80e6ff9..e6f40125f9 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -10,8 +10,10 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -114,19 +116,17 @@ def _compute_shared_expert( ): input = input.view(-1, self.embed_dim_) shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) + gate = layer_weight.shared_expert_gate.mm(input) + sigmoid_mul_(shared_expert_out, gate) return shared_expert_out def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) - hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) + shared_expert_gate = layer_weight.shared_expert_gate.mm(hidden_states) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -135,9 +135,9 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + shared_expert_gate=shared_expert_gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) - hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( @@ -169,13 +169,19 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkv_gate_out = layer_weight.qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_logics_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -204,8 +210,8 @@ def _get_o( if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) - infer_state.gate_value = None + sigmoid_mul_(input, infer_state.gate_logics_value) + infer_state.gate_logics_value = None o_tensor = layer_weight.o_proj.mm(input) o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor @@ -257,8 +263,9 @@ def gdn_forward( else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) - core_attn_out = self._gdn_decode_kernel( + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -406,6 +413,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -413,18 +421,25 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( mixed_qkv, + z, + a, + b, conv_states, layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + layer_weight.linear_conv1d.bias, + infer_state.b_buffer_idx, + self.activation, + self.conv_kernel_dim, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, ) - - # Recurrent processing with fused gating; the kernel reads the - # q/k/v/a/b column views directly via per-token strides (no copies) - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -438,4 +453,4 @@ def _gdn_decode_kernel( a_raw=a, b_raw=b, ) - return core_attn_out + return core_attn_out, z diff --git a/lightllm/models/qwen3next/layer_weights/qkv_gated_rowmm_weight.py b/lightllm/models/qwen3next/layer_weights/qkv_gated_rowmm_weight.py new file mode 100644 index 0000000000..c920b23fbd --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/qkv_gated_rowmm_weight.py @@ -0,0 +1,75 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + self.kv_repeat_times = self._get_kv_repeat_times(kv_head_num) + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.kv_repeat_times) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_kv_repeat_times(self, kv_head_num): + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + if kv_head_num % self.tp_world_size_ == 0: + return 1 + return self.tp_world_size_ // kv_head_num + + def _get_tp_padded_head_num(self, head_num, repeat_times): + return repeat_times * head_num // self.tp_world_size_ diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 0d415ca0e8..60901ad6b9 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -10,7 +10,9 @@ TpParameterWeight, QKVROWNMMWeight, QKGEMMANormWeight, + FusedMoeWeight, ) +from lightllm.models.qwen3next.layer_weights.qkv_gated_rowmm_weight import QKVGatedROWNMMWeight class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -22,25 +24,17 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed - q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + qkv_quant = self.get_quant_method("qkv_proj") + self.qkvo_gate_proj = QKVGatedROWNMMWeight( in_dim=in_dim, q_head_num=self.q_head_num_, kv_head_num=self.k_head_num_, head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) - self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, ) def _init_weight(self): @@ -57,8 +51,47 @@ def _init_weight(self): self._init_norm() def _init_moe(self): - super()._init_moe() - self._init_gated_ffn() + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + enable_ep_moe = get_env_start_args().enable_ep_moe + # Fused shared expert is only supported in TP mode. EP keeps the shared + # expert as a separate FFN and adds its output after routed MoE. + self.num_fused_shared_experts = 0 if enable_ep_moe else 1 + self.shared_expert_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + if enable_ep_moe: + self._init_moe_shared_expert_ffn() return def _init_norm(self): @@ -81,54 +114,25 @@ def _init_norm(self): data_type=self.data_type_, ) - def _init_gated_ffn(self): + def _init_moe_shared_expert_ffn(self): hidden_size = self.network_config_["hidden_size"] - if "shared_expert_intermediate_size" not in self.network_config_: - return prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" inter_size = self.network_config_["shared_expert_intermediate_size"] - if get_env_start_args().enable_ep_moe: - self.gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("gate_up_proj"), - tp_rank=0, - tp_world_size=1, - ) - self.down_proj = COLMMWeight( - in_dim=inter_size, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("down_proj"), - tp_rank=0, - tp_world_size=1, - ) - else: - self.gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("gate_up_proj"), - ) - self.down_proj = COLMMWeight( - in_dim=inter_size, - out_dims=[hidden_size], - weight_names=f"{prefix}.down_proj.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("down_proj"), - ) - - self.ffn_gate = ROWMMWeight( + self.gate_up_proj = ROWMMWeight( in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], data_type=self.data_type_, - bias_names=None, - quant_method=None, + quant_method=self.get_quant_method("gate_up_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), tp_rank=0, tp_world_size=1, ) @@ -143,6 +147,29 @@ def _split_q_with_gate(self, weights): weights[self._q_weight_name] = _q_proj weights[self._o_gate_weight_name] = _gate_proj + def _rename_shared_expert_to_moe_expert(self, weights): + if self.num_fused_shared_experts != 1: + return + assert not get_env_start_args().enable_ep_moe, "fused shared expert is only supported in TP mode" + assert self.num_fused_shared_experts == 1, "only one fused shared expert is supported" + + # When the shared expert is fused into MoE, load it as the last routed expert. + # The fused MoE kernel then treats expert id n_routed_experts as this shared expert. + old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts.{self.n_routed_experts}" + suffixes = [ + self.experts.quant_method.weight_suffix, + self.experts.quant_method.weight_scale_suffix, + self.experts.quant_method.weight_zero_point_suffix, + ] + for proj_name in ("gate_proj", "up_proj", "down_proj"): + for suffix in suffixes: + if suffix is None: + continue + old_name = f"{old_prefix}.{proj_name}.{suffix}" + if old_name in weights: + weights[f"{new_prefix}.{proj_name}.{suffix}"] = weights[old_name] + def _parse_config(self): super()._parse_config() self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] @@ -288,6 +315,8 @@ def _parse_linear_conv1d(self, weight): def load_hf_weights(self, weights): self._split_q_with_gate(weights) + if self.is_moe: + self._rename_shared_expert_to_moe_expert(weights) if self.is_linear_attention_layer: self._preprocess_weight(weights) super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..c4efec47ea --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,177 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + KERNEL_SIZE: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + # KERNEL_SIZE is a constexpr, so Triton fully unrolls these loops for each conv size. + y = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + for i in tl.static_range(0, KERNEL_SIZE - 1): + s = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + i * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w = tl.load(conv_weight + offs * stride_w_d + i * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y += s * w + + w = tl.load(conv_weight + offs * stride_w_d + (KERNEL_SIZE - 1) * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y += x * w + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + for i in tl.static_range(0, KERNEL_SIZE - 2): + next_s = tl.load( + conv_state + state_idx * stride_s_b + offs * stride_s_d + (i + 1) * stride_s_w, mask=mask, other=0.0 + ) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + i * stride_s_w, next_s, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + (KERNEL_SIZE - 2) * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + conv_size: int, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + assert conv_size >= 2, f"conv kernel size must be at least 2, got {conv_size}" + assert mixed_qkv.shape[1] == conv_dim, f"mixed_qkv shape mismatch: {mixed_qkv.shape[1]} != {conv_dim}" + assert conv_weight.shape[0] == conv_dim, f"conv_weight shape mismatch: {conv_weight.shape[0]} != {conv_dim}" + assert conv_weight.shape[1] == conv_size, f"conv_weight kernel mismatch: {conv_weight.shape[1]} != {conv_size}" + assert conv_state.shape[1] == conv_dim, f"conv_state shape mismatch: {conv_state.shape[1]} != {conv_dim}" + assert ( + conv_state.shape[2] >= conv_size - 1 + ), f"conv_state width must be at least conv_size - 1, got {conv_state.shape[2]} and {conv_size}" + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + conv_size, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..8b73cfd74d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,50 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = tl.sigmoid(gate_vals) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x=x_arg, + gate=gate_arg, + stride_x_m=x_arg.stride(0), + stride_x_n=x_arg.stride(1), + stride_g_m=gate_arg.stride(0), + stride_g_n=gate_arg.stride(1), + N=n, + GATE_N=gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return x diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..04e0187452 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -683,7 +683,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_fused_shared_experts", action="store_true", - help="""Whether to enable fused shared experts for deepseekv3 model. only work when tensor parallelism""", + help="""Whether to enable fused shared experts for supported MoE models. It is auto-enabled when supported.""", ) parser.add_argument( "--mtp_mode", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3cf431d650..c9b82e1e0c 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -23,6 +23,7 @@ has_vision_module, is_linear_att_mixed_model, auto_set_max_req_total_len, + auto_set_fused_shared_experts, ) from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args @@ -76,6 +77,7 @@ def normal_or_p_d_start(args): args: StartArgs = args auto_set_max_req_total_len(args) + auto_set_fused_shared_experts(args) set_unique_server_name(args) if args.enable_mps: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 40c8028158..bfc03cd542 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -157,6 +157,7 @@ class StartArgs: enable_ep_moe: bool = field(default=False) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) mtp_mode: Optional[str] = field( default=None, metadata={ diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c8d7373d54..85d21477b0 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -146,6 +146,44 @@ def auto_set_max_req_total_len(args) -> None: logger.info(f"auto derived max_req_total_len={args.max_req_total_len} from model config") +def auto_set_fused_shared_experts(args) -> None: + """ + Route fused shared experts to supported model families and write the final + decision to `args.enable_fused_shared_experts`. + """ + + if args.enable_fused_shared_experts: + logger.info("skip auto setting fused shared experts: already enabled") + return + + if args.enable_ep_moe: + logger.info("do not enable fused shared experts: EP MoE uses a separate implementation") + return + + model_dir = args.model_dir + if not model_dir: + logger.info("do not enable fused shared experts: model_dir is empty") + return + + model_type = get_model_type(model_dir) + supported_model_types = { + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "qwen3_next", + "qwen3_5", + "qwen3_5_text", + "qwen3_5_moe", + "qwen3_5_moe_text", + } + if model_type not in supported_model_types: + logger.info(f"do not enable fused shared experts: unsupported model_type={model_type}") + return + + args.enable_fused_shared_experts = True + logger.info(f"auto enable fused shared experts for model_type={model_type}") + + def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): config_json = get_config_json(model_path) for key in key_name: