1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
CLAUDE.md
@@ -33,6 +33,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
moe_layer_index: int = 0,
) -> None:
super().__init__(data_type=data_type)
self.w1_weight_name = gate_proj_name
@@ -50,6 +51,7 @@ def __init__(
self.enable_ep_moe = get_env_start_args().enable_ep_moe
self.n_routed_experts = n_routed_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.moe_layer_index = moe_layer_index
self._init_config(network_config)
self._init_redundancy_expert_params()
self._init_parallel_params()
@@ -130,6 +132,7 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
microbatch_index: int = 0,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
moe_layer_index=self.moe_layer_index,
microbatch_index=microbatch_index,
)

def low_latency_dispatch(
@@ -8,6 +8,7 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.routing_manager import g_routing_capture_manager

logger = init_logger(__name__)

@@ -46,6 +47,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
moe_layer_index: int = 0,
) -> None:
network_config["norm_topk_prob"] = None
super().__init__(
@@ -62,6 +64,7 @@
num_fused_shared_experts=num_fused_shared_experts,
layer_num=layer_num,
network_config=network_config,
moe_layer_index=moe_layer_index,
)

self.hidden_size = network_config["hidden_size"]
@@ -144,10 +147,15 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
microbatch_index: int = 0,
):

topk_weights, topk_ids = self._router(router_logits, top_k)

# Capture routing decisions for rollout router replay.
if g_routing_capture_manager is not None:
g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
@@ -62,5 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
moe_layer_index: Optional[int] = None,
microbatch_index: int = 0,
) -> torch.Tensor:
pass
@@ -3,6 +3,7 @@
from lightllm.common.quantization.no_quant import WeightPack
from lightllm.common.quantization.quantize_method import QuantizationMethod
from .base_impl import FuseMoeBaseImpl
from lightllm.common.basemodel.routing_manager import g_routing_capture_manager


class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
moe_layer_index: Optional[int] = None,
microbatch_index: int = 0,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
@@ -136,6 +139,10 @@
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)

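# Capture routing decisions for rollout router replay when a capture manager is active.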
if g_routing_capture_manager is not None and moe_layer_index is not None:
g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)

output = self._fused_experts(
input_tensor=input_tensor,
w13=w13,
155 changes: 155 additions & 0 deletions lightllm/common/basemodel/routing_manager.py
@@ -0,0 +1,155 @@
import torch
import numpy as np
from typing import Optional
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_current_rank_in_dp
from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
from lightllm.utils.envs_utils import get_unique_server_name

logger = init_logger(__name__)


def get_routing_config_shm() -> SharedArray:
"""Get shared memory for MoE routing config: arr[0]=num_moe_layers, arr[1]=topk."""
service_name = get_unique_server_name()
return SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)


class RoutingCaptureManager:
"""Captures MoE routing decisions"""

def __init__(
self,
num_moe_layers: int,
topk: int,
num_experts: int,
batch_max_tokens: int,
kv_cache_size: int,
enable_overlap: bool = False,
):
self.num_moe_layers = num_moe_layers
self.topk = topk
self.num_experts = num_experts
self.batch_max_tokens = batch_max_tokens
self.kv_cache_size = kv_cache_size

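# Expert ids fit in int8 when there are at most 127 experts; larger expert counts fall back to int16.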
self.dtype = torch.int8 if num_experts <= 127 else torch.int16
dtype_bytes = 1 if self.dtype == torch.int8 else 2

self.num_slots = 2 if enable_overlap else 1

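# Per-step GPU staging buffer shaped [slot, moe_layer, token, topk]; two slots are kept when decode
# microbatch overlap is enabled so concurrent microbatches do not overwrite each other's captures.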
gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes
self.gpu_buffer = torch.zeros(
(self.num_slots, num_moe_layers, batch_max_tokens, topk),
dtype=self.dtype,
device="cuda",
)

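# Pinned CPU buffer indexed by KV-cache slot (mem index), so a request's routing history can be
# gathered later from its mem_indexes.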
cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
self.cpu_buffer = torch.zeros(
(num_moe_layers, kv_cache_size, topk),
dtype=self.dtype,
device="cpu",
pin_memory=True,
)

self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)]
self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)]

dtype_name = "int8" if self.dtype == torch.int8 else "int16"
logger.info(
f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, "
f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}"
)

def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
num_tokens = topk_ids.shape[0]
self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)

def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
num_tokens = mem_indexes.shape[0]
if num_tokens == 0:
return

slot = microbatch_index % self.num_slots
stream = self.flush_streams[slot]
event = self.flush_events[slot]

stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(stream):
cpu_indexes = mem_indexes.cpu()
self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
event.record()

def sync_events(self) -> None:
"""Synchronize all flush events. Call once before batch extraction."""
for event in self.flush_events:
event.synchronize()

def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
self.sync_events()
return self.cpu_buffer[:, mem_indexes, :].numpy()

def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
return self.cpu_buffer[:, mem_indexes, :].numpy()


g_routing_capture_manager: Optional[RoutingCaptureManager] = None


def create_routing_capture_manager(
num_moe_layers: int,
topk: int,
num_experts: int,
batch_max_tokens: int,
kv_cache_size: int,
enable_overlap: bool = False,
) -> None:
global g_routing_capture_manager
assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
g_routing_capture_manager = RoutingCaptureManager(
num_moe_layers=num_moe_layers,
topk=topk,
num_experts=num_experts,
batch_max_tokens=batch_max_tokens,
kv_cache_size=kv_cache_size,
enable_overlap=enable_overlap,
)


def init_routing_capture(model, num_moe_layers: int) -> None:
if get_current_rank_in_dp() != 0:
# Routing capture is only initialized on rank 0 within each DP group.
return

if num_moe_layers == 0:
logger.warning(
"enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled."
)
return

num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
topk = model.config.get("num_experts_per_tok", 0)
assert num_experts > 0 and topk > 0, "routing capture requires num_experts and topk in the model config"
enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)

logger.info(
f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
)

create_routing_capture_manager(
num_moe_layers=num_moe_layers,
topk=topk,
num_experts=num_experts,
batch_max_tokens=model.max_total_token_num,
kv_cache_size=model.mem_manager.size + 1,
enable_overlap=enable_overlap,
)

shm = get_routing_config_shm()
shm.arr[0] = num_moe_layers
shm.arr[1] = topk
logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")
@@ -312,6 +312,7 @@ def _moe_ffn(
use_grouped_topk=self.n_group,
topk_group=self.topk_group,
num_expert_group=self.n_group,
microbatch_index=infer_state.microbatch_index,
)

if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -339,6 +340,7 @@ def _moe_ffn_edp(
topk_group=self.topk_group,
num_expert_group=self.n_group,
is_prefill=infer_state.is_prefill,
microbatch_index=infer_state.microbatch_index,
)

if self.n_shared_experts is not None:
@@ -242,6 +242,9 @@ def _init_moe(self):
# When == 0, there are no fused shared experts; the shared experts are loaded and run separately.
if self.num_fused_shared_experts == 0:
self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True)
first_moe = self.network_config_["first_k_dense_replace"]
freq = self.network_config_.get("moe_layer_freq", 1)
moe_layer_index = (self.layer_num_ - first_moe) // freq
self.experts = FusedMoeWeight(
gate_proj_name="gate_proj",
down_proj_name="down_proj",
@@ -256,6 +259,7 @@
num_fused_shared_experts=self.num_fused_shared_experts,
layer_num=self.layer_num_,
network_config=self.network_config_,
moe_layer_index=moe_layer_index,
)

def _init_ffn(self):
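The index computed here maps the absolute transformer layer number onto a dense index over MoE layers only: the first first_k_dense_replace layers are dense, and moe_layer_freq accounts for configs that interleave dense and MoE layers. A quick illustration with assumed config values (not taken from any real checkpoint):

first_k_dense_replace = 3  # layers 0-2 use a dense MLP
moe_layer_freq = 1         # every layer from layer 3 onward is MoE
for layer_num in (3, 4, 10):
    print(layer_num, (layer_num - first_k_dense_replace) // moe_layer_freq)
# prints: 3 0, 4 1, 10 7 -> layer 3 is MoE layer 0, layer 10 is MoE layer 7
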
4 changes: 4 additions & 0 deletions lightllm/models/deepseek2/model.py
@@ -6,6 +6,7 @@
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
from lightllm.distributed.communication_op import dist_group_manager
@@ -49,6 +50,9 @@ def _init_some_value(self):
def _init_custom(self):
self._init_to_get_yarn_rotary()
dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
if self.args.enable_return_routed_experts:
num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
init_routing_capture(self, num_moe_layers)

def _verify_params(self):
return super()._verify_params()
@@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
microbatch_index=infer_state.microbatch_index,
)
return hidden_states.view(num_tokens, hidden_dim)

@@ -55,6 +55,7 @@ def _init_moe(self):
num_fused_shared_experts=0,
layer_num=self.layer_num_,
network_config=self.network_config_,
moe_layer_index=self.layer_num_,
)

def _init_weight_names(self):
6 changes: 6 additions & 0 deletions lightllm/models/gpt_oss/model.py
@@ -2,6 +2,7 @@
from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.models.registry import ModelRegistry
from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class
@@ -28,3 +29,8 @@ def _init_att_backend(self):
self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])(
model=self
)

def _init_custom(self):
super()._init_custom()
num_moe_layers = len(self.trans_layers_weight)
init_routing_capture(self, num_moe_layers)
10 changes: 6 additions & 4 deletions lightllm/models/llama/model.py
@@ -74,14 +74,17 @@ def _init_custom(self):
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
self._init_to_get_rotary()
return

if "rope_type" in rope_scaling:
elif "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
else:
raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
super()._init_custom()

def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling):
if scaling_type == "default" or "mrope_section" in rope_scaling:
self._init_to_get_rotary()
elif scaling_type == "yarn":
@@ -96,7 +99,6 @@ def _init_custom(self):
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return

def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)