Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
shared_expert_gate: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
Expand All @@ -150,6 +151,7 @@ def experts(
num_expert_group=num_expert_group,
is_prefill=is_prefill,
per_expert_scale=self.per_expert_scale,
shared_expert_gate=shared_expert_gate,
)

def low_latency_dispatch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,6 @@ def __call__(
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
shared_expert_gate: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
shared_expert_gate: Optional[torch.Tensor] = None,
):
assert shared_expert_gate is None, "fused shared expert as MoE is not supported by DeepGEMM fused MoE"
output = fused_experts(
hidden_states=input_tensor,
w13=w13,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
shared_expert_gate: Optional[torch.Tensor] = None,
):
assert shared_expert_gate is None, "fused shared expert as MoE is not supported by Marlin fused MoE"

w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point
w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,6 @@ def _select_experts(
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
start=self.n_routed_experts,
end=self.n_routed_experts + self.num_fused_shared_experts,
step=1,
dtype=topk_ids.dtype,
device="cuda",
)
.view(1, self.num_fused_shared_experts)
.repeat(topk_ids.shape[0], 1)
)
pad_topk_weights = torch.full(
(topk_weights.shape[0], self.num_fused_shared_experts),
fill_value=1.0,
device="cuda",
dtype=topk_weights.dtype,
)

topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)
return topk_weights, topk_ids

def _fused_experts(
Expand All @@ -94,11 +73,18 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: bool = False,
shared_expert_gate: Optional[torch.Tensor] = None,
):
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
use_fp8_w8a8 = w13_weight.dtype == torch.float8_e4m3fn

if shared_expert_gate is not None:
assert (
type(self) is FuseMoeTriton
), "fused shared expert as MoE is only supported by the Triton fused MoE implementation"
assert self.num_fused_shared_experts > 0, "shared_expert_gate requires fused shared experts"

from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_experts

fused_experts(
Expand All @@ -111,6 +97,8 @@ def _fused_experts(
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w13_scale,
w2_scale=w2_scale,
shared_expert_id=self.n_routed_experts if self.num_fused_shared_experts > 0 else -1,
shared_expert_gate=shared_expert_gate,
)
return input_tensor

Expand All @@ -129,6 +117,7 @@ def __call__(
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
shared_expert_gate: Optional[torch.Tensor] = None,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
Expand All @@ -150,5 +139,6 @@ def __call__(
topk_ids=topk_ids,
router_logits=router_logits,
is_prefill=is_prefill,
shared_expert_gate=shared_expert_gate,
)
return output
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,8 @@ def __init__(
) -> None:
self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
self.repeat_times = 1
assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, (
f"kv_head_num must be divisible by tp_world_size_ or "
f"tp_world_size_ must be divisible by kv_head_num, "
f"but found: {kv_head_num} % {self.tp_world_size_}"
)
kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim
self.repeat_times = self._get_repeat_times(kv_head_num)
kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.repeat_times) * head_dim
out_dims = [kv_hidden_size, kv_hidden_size]
super().__init__(
in_dim=in_dim,
Expand All @@ -78,18 +73,19 @@ def __init__(
repeat_times=self.repeat_times,
)

def _get_tp_padded_head_num(self, head_num: int):
if head_num % self.tp_world_size_ == 0:
return head_num // self.tp_world_size_
elif self.tp_world_size_ % head_num == 0:
self.repeat_times = self.tp_world_size_ // head_num
return self.repeat_times * head_num // self.tp_world_size_
def _get_repeat_times(self, kv_head_num: int) -> int:
assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, (
f"kv_head_num must be divisible by tp_world_size_ or "
f"tp_world_size_ must be divisible by kv_head_num, "
f"but found: {kv_head_num} % {self.tp_world_size_}"
)
if kv_head_num % self.tp_world_size_ == 0:
return 1
else:
raise ValueError(
f"head_num must be divisible by tp_world_size_ or "
f"tp_world_size_ must be divisible by head_num, "
f"but found: {head_num} % {self.tp_world_size_}"
)
return self.tp_world_size_ // kv_head_num

def _get_tp_padded_head_num(self, head_num: int, repeat_times: int) -> int:
return repeat_times * head_num // self.tp_world_size_


class QKVROWNMMWeight(MMWeightTpl):
Expand All @@ -109,17 +105,12 @@ def __init__(
self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
self.q_repeat_times = 1
self.kv_repeat_times = 1
self.kv_repeat_times = self._get_kv_repeat_times(kv_head_num)
assert q_head_num % self.tp_world_size_ == 0, (
f"q_head_num must be divisible by tp_world_size_, " f"but found: {q_head_num} % {self.tp_world_size_}"
)
assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, (
f"kv_head_num must be divisible by tp_world_size_ or "
f"tp_world_size_ must be divisible by kv_head_num, "
f"but found: {kv_head_num} % {self.tp_world_size_}"
)
q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim
kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim
kv_hidden_size = self._get_tp_padded_head_num(kv_head_num, self.kv_repeat_times) * head_dim
out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size]
super().__init__(
in_dim=in_dim,
Expand Down Expand Up @@ -157,18 +148,19 @@ def _get_param_slicer(self, sub_child_index: int):
else:
return self.kv_param_slicer

def _get_tp_padded_head_num(self, head_num: int):
if head_num % self.tp_world_size_ == 0:
return head_num // self.tp_world_size_
elif self.tp_world_size_ % head_num == 0:
self.kv_repeat_times = self.tp_world_size_ // head_num
return self.kv_repeat_times * head_num // self.tp_world_size_
def _get_kv_repeat_times(self, kv_head_num: int) -> int:
assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, (
f"kv_head_num must be divisible by tp_world_size_ or "
f"tp_world_size_ must be divisible by kv_head_num, "
f"but found: {kv_head_num} % {self.tp_world_size_}"
)
if kv_head_num % self.tp_world_size_ == 0:
return 1
else:
raise ValueError(
f"head_num must be divisible by tp_world_size_ or "
f"tp_world_size_ must be divisible by head_num, "
f"but found: {head_num} % {self.tp_world_size_}"
)
return self.tp_world_size_ // kv_head_num

def _get_tp_padded_head_num(self, head_num: int, repeat_times: int) -> int:
return repeat_times * head_num // self.tp_world_size_


class ROWBMMWeight(BMMWeightTpl):
Expand Down
Loading
Loading