diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a2..5437c24363 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -33,8 +33,13 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .attention import get_prefill_att_backend_class, get_decode_att_backend_class from .attention import BaseAttBackend +from . import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -90,6 +95,7 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) self.is_mtp_mode = self.args.mtp_mode in [ "vanilla_with_att", "eagle_with_att", @@ -103,18 +109,22 @@ def __init__(self, kvargs): self._verify_params() self._init_quant() - self._init_weights() - self._init_req_manager() - self._init_mem_manager() + enable_weight_cpu_backup = self.args.enable_weight_cpu_backup + with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=enable_weight_cpu_backup): + self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_req_manager() + self._init_mem_manager() + # 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state # 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值 self.req_manager.mem_manager = self.mem_manager - self._check_mem_size() self._init_infer_layer() self._init_some_value() self._init_custom() - self._load_hf_weights() + self._init_routing_capture() + self.load_weights(self.weight_dict) self._init_att_backend() self._init_att_backend1() @@ -172,13 
+182,14 @@ def _init_weights(self, start_layer_index=0): ] return - def _load_hf_weights(self): + def load_weights(self, weight_dict: dict): + assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None" load_hf_weights( - self.data_type, + data_type=self.data_type, weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, + weight_dict=weight_dict, ) self.pre_post_weight.verify_load() [weight.verify_load() for weight in self.trans_layers_weight] @@ -289,6 +300,17 @@ def _init_prefill_cuda_graph(self): def _init_custom(self): pass + def _init_routing_capture(self): + if not self.args.enable_return_routed_experts: + return + if _routing_mgr.g_routing_capture_manager is not None: + # MTP draft models share the main model process and KV cache, so they + # should reuse the routing capture manager initialized by the main model. + logger.info("RoutingCaptureManager already initialized, skip routing capture init.") + return + _routing_mgr.init_routing_capture(self) + return + @torch.no_grad() def forward(self, model_input: ModelInput): model_input.to_cuda() @@ -333,6 +355,11 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.mem_index = model_input.mem_indexes infer_state.microbatch_index = microbatch_index infer_state.dist_group = dist_group_manager.get_group(microbatch_index) + mgr = _routing_mgr.g_routing_capture_manager + if mgr is not None: + # Build a callback that records each MoE layer's top-k experts into + # the routing buffer at this forward's KV-cache positions. 
+ infer_state.make_routing_capture_callback = mgr.make_capture_callback_factory(infer_state.mem_index) # 特殊模型,特殊模式的特定变量初始化操作。 infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens @@ -1031,6 +1058,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + torch.cuda.empty_cache() return def autotune_layers(self): @@ -1165,6 +1193,9 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output + del b_mtp_index + del b_prefill_start_loc + del b_q_seq_len torch.cuda.empty_cache() return diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..5e8036ee81 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -8,6 +8,10 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .infer_struct import InferStateInfo @@ -26,6 +30,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -94,7 +99,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): if param_name not in pure_para_set: delattr(infer_state, param_name) - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() 
@@ -128,7 +133,7 @@ def _capture_decode_overlap( if para_name not in pure_para_set1: delattr(infer_state1, para_name) - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..e09452b5ab 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -100,6 +100,9 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None + # Optional hook for recording MoE routing top-k ids during forward. + self.make_routing_capture_callback = None + def init_some_extra_state(self, model): if self.is_prefill: ( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..1e7f94b314 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -124,6 +124,12 @@ def _init_parallel_params(self): self.expert_idx_to_local_idx = {expert_idx: i for (i, expert_idx) in enumerate(self.local_expert_ids)} self.rexpert_idx_to_local_idx = {} + def _make_routing_capture_callback(self, infer_state): + make_routing_capture_callback = getattr(infer_state, "make_routing_capture_callback", None) + if make_routing_capture_callback is None: + return None + return make_routing_capture_callback(self.layer_num_) + def experts( self, 
input_tensor: torch.Tensor, @@ -134,8 +140,9 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + infer_state=None, ) -> torch.Tensor: - """Backward compatible method that routes to platform-specific implementation.""" + routing_capture_callback = self._make_routing_capture_callback(infer_state) return self.fuse_moe_impl( input_tensor=input_tensor, router_logits=router_logits, @@ -149,6 +156,7 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + routing_capture_callback=routing_capture_callback, per_expert_scale=self.per_expert_scale, ) @@ -317,6 +325,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=self.local_n_routed_experts, ) + self.w1, self.w3 = w13_param_list self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) @@ -339,6 +348,8 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): + # for merged weights + self._load_merge_weight(weights) # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: @@ -363,6 +374,7 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: @@ -372,6 +384,19 @@ def _load_expert( if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), 
self.w2_list[local_expert_idx]) + def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): + w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" + w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" + w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = self.col_slicer._slice_weight + if w1_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) + if w2_merge_weight in weights: + self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) + if w3_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) + def _load_expert_scale( self, expert_idx: int, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..2e82571f51 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -144,10 +144,15 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + infer_state=None, ): topk_weights, topk_ids = self._router(router_logits, top_k) + routing_capture_callback = self._make_routing_capture_callback(infer_state) + if routing_capture_callback is not None: + routing_capture_callback(topk_ids) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index dd6f9a6880..443a271cd4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -1,10 +1,10 @@ import torch from abc import abstractmethod +from typing import Callable, Optional from lightllm.common.quantization.quantize_method import ( WeightPack, QuantizationMethod, ) -from typing import Optional from lightllm.utils.dist_utils import ( get_global_rank, get_global_world_size, @@ -62,6 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + routing_capture_callback: Optional[Callable[[torch.Tensor], None]] = None, per_expert_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..b5d231f58e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -1,5 +1,5 @@ import torch -from typing import Optional +from typing import Callable, Optional from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl @@ -62,6 +62,7 @@ def _select_experts( topk_weights.mul_(self.routed_scaling_factor) if per_expert_scale is not None: topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) + routed_topk_ids = topk_ids if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( @@ -83,7 +84,7 @@ def _select_experts( topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - return topk_weights, topk_ids + return topk_weights, topk_ids, routed_topk_ids def _fused_experts( self, @@ -128,9 +129,10 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + 
routing_capture_callback: Optional[Callable[[torch.Tensor], None]] = None, per_expert_scale: Optional[torch.Tensor] = None, ): - topk_weights, topk_ids = self._select_experts( + selected_experts = self._select_experts( input_tensor=input_tensor, router_logits=router_logits, correction_bias=correction_bias, @@ -142,6 +144,15 @@ def __call__( scoring_func=scoring_func, per_expert_scale=per_expert_scale, ) + if len(selected_experts) == 2: + topk_weights, topk_ids = selected_experts + routed_topk_ids = topk_ids + else: + topk_weights, topk_ids, routed_topk_ids = selected_experts + + if routing_capture_callback is not None: + routing_capture_callback(routed_topk_ids) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..067c1c8ca9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -28,6 +28,10 @@ def _get_slice_start_end(self, size: int) -> Tuple[int, int]: end = start + tp_size return start, end + def _assert_weight_ndim(self, tensor: torch.Tensor) -> None: + # 2D: 普通 linear (out, in); 3D: MoE 合并权重 (num_experts, out, in)。 + assert tensor.dim() in (2, 3), f"expect weight ndim in (2, 3), got shape {tuple(tensor.shape)}" + class SliceMixinTpl(SliceMixinBase): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): @@ -46,18 +50,20 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten raise NotImplementedError("slice_weight_zero_point must implement this method") -# 默认weight 的shape是 outxin,这也是目前最通用的约定。 -# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +# 默认 weight 的 shape 末两维是 (out, in),普通 linear 是 2D (out, in), +# MoE 合并权重则是 3D (num_experts, out, in),统一通过 `...` 处理任意前导维。 +# 约定 row-wise 沿着 out 
维(倒数第二维)切分,col-wise 沿着 in 维(最后一维)切分。 class RowSliceMixin(SliceMixinTpl): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + self._assert_weight_ndim(weight) assert ( - weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[0]) - return weight[start:end, :] + weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-2]) + return weight[..., start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: assert ( @@ -74,18 +80,20 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + self._assert_weight_ndim(weight_scale) assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[0]) - return weight_scale[start:end] + weight_scale.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-2]) + return weight_scale[..., start:end, :] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + self._assert_weight_ndim(weight_zero_point) assert ( - weight_zero_point.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[0]) - return weight_zero_point[start:end] + 
weight_zero_point.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-2]) + return weight_zero_point[..., start:end, :] class ColSliceMixin(SliceMixinTpl): @@ -93,11 +101,12 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + self._assert_weight_ndim(weight) assert ( - weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[1]) - return weight[:, start:end] + weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-1]) + return weight[..., start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ * self.repeat_times_ @@ -108,18 +117,20 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + self._assert_weight_ndim(weight_scale) assert ( - weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[1]) - return weight_scale[:, start:end] + weight_scale.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-1]) + return weight_scale[..., start:end] def 
_slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + self._assert_weight_ndim(weight_zero_point) assert ( - weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[1]) - return weight_zero_point[:, start:end] + weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-1]) + return weight_zero_point[..., start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 0000000000..a154415b59 --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,276 @@ +import atexit +import json +import os +import torch +import numpy as np +from multiprocessing import shared_memory +from typing import Dict, List, Optional, Tuple +from lightllm.common.basemodel.triton_kernel.routing_capture import scatter_routing_topk_to_cpu +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.utils.envs_utils import get_unique_server_name + +logger = init_logger(__name__) + + +def routing_dtype_id_to_np(dtype_id: int): + if dtype_id == 1: + return np.uint8 + elif dtype_id == 2: + return np.int16 + return np.int32 + + +def _get_model_text_config(config: dict) -> dict: + return config.get("text_config", config) + + +def _get_num_moe_layers_from_config(config: dict) -> int: + num_layers = config.get("num_hidden_layers", config.get("n_layer", config.get("num_layers", 0))) + num_experts = config.get("n_routed_experts", config.get("num_experts", config.get("num_local_experts", 0))) + if num_layers <= 0 or not 
num_experts: + return 0 + + if "first_k_dense_replace" in config: + first_k_dense_replace = config.get("first_k_dense_replace", 0) + moe_layer_freq = config.get("moe_layer_freq", 1) + return sum( + 1 + for layer_num in range(num_layers) + if layer_num >= first_k_dense_replace and layer_num % moe_layer_freq == 0 + ) + + if "mlp_only_layers" in config or "decoder_sparse_step" in config: + mlp_only_layers = set(config.get("mlp_only_layers", [])) + decoder_sparse_step = config.get("decoder_sparse_step", 1) + return sum( + 1 + for layer_num in range(num_layers) + if layer_num not in mlp_only_layers and (layer_num + 1) % decoder_sparse_step == 0 + ) + + if config.get("enable_moe_block", False): + return num_layers + + return num_layers + + +def get_routing_config_from_model_dir(model_dir: str) -> Optional[Tuple[int, int, int]]: + with open(os.path.join(model_dir, "config.json"), "r") as json_file: + config = _get_model_text_config(json.load(json_file)) + + num_moe_layers = _get_num_moe_layers_from_config(config) + topk = config.get("num_experts_per_tok", config.get("top_k_experts", 0)) + num_experts = config.get("n_routed_experts", config.get("num_experts", config.get("num_local_experts", 0))) + if num_moe_layers <= 0 or topk <= 0 or not num_experts: + return None + + dtype_id = 1 if num_experts <= 256 else 2 + return num_moe_layers, topk, dtype_id + + +class RoutingCaptureManager: + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + layer_num_to_moe_index: Optional[Dict[int, int]] = None, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.kv_cache_size = kv_cache_size + self.max_capture_tokens = max_capture_tokens + self.layer_num_to_moe_index = layer_num_to_moe_index or {i: i for i in range(num_moe_layers)} + + self.dtype = torch.uint8 if num_experts <= 256 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.uint8 else 2 + + # Shape: 
(kv_cache_size, num_moe_layers, topk). Pinned CPU memory saves GPU memory + # while allowing the Triton scatter kernel to write without a synchronous D2H copy. + routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.routing_buffer = torch.zeros( + (kv_cache_size, num_moe_layers, topk), + dtype=self.dtype, + device="cpu", + pin_memory=True, + ) + self.routing_buffer_ptr = torch.tensor([self.routing_buffer.data_ptr()], dtype=torch.uint64, device="cuda") + + dtype_name = "uint8" if self.dtype == torch.uint8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " + f"dtype={dtype_name}" + ) + + @property + def np_dtype(self): + return np.uint8 if self.dtype == torch.uint8 else np.int16 + + @property + def dtype_id(self) -> int: + return 1 if self.dtype == torch.uint8 else 2 + + def make_capture_callback_factory(self, mem_indexes: torch.Tensor): + if not mem_indexes.is_cuda: + mem_indexes = mem_indexes.cuda(non_blocking=True) + + def make_capture_callback(layer_num: int): + routing_layer_index = self.layer_num_to_moe_index.get(layer_num) + if routing_layer_index is None: + return None + + def capture_callback(topk_ids: torch.Tensor) -> None: + self.capture(routing_layer_index=routing_layer_index, topk_ids=topk_ids, mem_indexes=mem_indexes) + + return capture_callback + + return make_capture_callback + + def capture(self, routing_layer_index: int, topk_ids: torch.Tensor, mem_indexes: torch.Tensor) -> None: + assert topk_ids.dim() == 2 + assert topk_ids.shape[1] == self.topk + assert mem_indexes.shape[0] >= topk_ids.shape[0] + scatter_routing_topk_to_cpu( + topk_ids=topk_ids, + mem_indexes=mem_indexes, + routing_buffer_ptr=self.routing_buffer_ptr, + moe_layer_index=routing_layer_index, + num_moe_layers=self.num_moe_layers, + topk=self.topk, + dtype_id=self.dtype_id, + ) + + def extract_routing_data(self, mem_indexes: 
torch.Tensor) -> np.ndarray: + torch.cuda.current_stream().synchronize() + cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes + return self.routing_buffer[cpu_indexes, :, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + layer_num_to_moe_index: Optional[Dict[int, int]] = None, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=kv_cache_size, + max_capture_tokens=max_capture_tokens, + layer_num_to_moe_index=layer_num_to_moe_index, + ) + + +def _get_moe_layer_nums(model) -> List[int]: + moe_layer_nums = [] + for layer_weight in getattr(model, "trans_layers_weight", []): + is_moe = getattr(layer_weight, "is_moe", None) + if is_moe is None: + is_moe = hasattr(layer_weight, "experts") + if is_moe: + moe_layer_nums.append(layer_weight.layer_num_) + return moe_layer_nums + + +def cleanup_routing_shm_pool() -> None: + """Unlink all pre-allocated routing SHM segments. 
Called at server shutdown.""" + try: + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + except Exception: + return + + service_name = get_unique_server_name() + + for i in range(args.running_max_req_size): + name = f"{service_name}_shm_routing_{i}" + try: + shm = shared_memory.SharedMemory(name=name) + shm.close() + shm.unlink() + except Exception: + pass + + config_name = f"{service_name}_routing_config" + try: + shm = shared_memory.SharedMemory(name=config_name) + shm.close() + shm.unlink() + except Exception: + pass + + +def init_routing_capture(model, num_moe_layers: Optional[int] = None) -> None: + moe_layer_nums = _get_moe_layer_nums(model) + if num_moe_layers is None: + num_moe_layers = len(moe_layer_nums) + elif moe_layer_nums: + assert num_moe_layers == len(moe_layer_nums) + else: + moe_layer_nums = list(range(num_moe_layers)) + layer_num_to_moe_index = {layer_num: moe_index for moe_index, layer_num in enumerate(moe_layer_nums)} + + dp_rank = get_current_rank_in_dp() + logger.info( + f"init_routing_capture called: num_moe_layers={num_moe_layers}, " + f"moe_layer_nums={moe_layer_nums}, dp_rank={dp_rank}" + ) + if dp_rank != 0: + logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}") + return + + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled." + ) + return + + num_experts = model.config.get( + "n_routed_experts", + model.config.get("num_experts", model.config.get("num_local_experts", 0)), + ) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 + + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + + # Capture buffer must fit the max tokens in any single forward call. + # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size. 
+ batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192 + max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=model.mem_manager.size + 1, + max_capture_tokens=max_capture_tokens, + layer_num_to_moe_index=layer_num_to_moe_index, + ) + + logger.info( + f"Routing capture config set: num_moe_layers={num_moe_layers}, topk={topk}, " + f"dtype_id={g_routing_capture_manager.dtype_id}, max_tokens={args.max_req_total_len}" + ) + atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/common/basemodel/triton_kernel/routing_capture.py b/lightllm/common/basemodel/triton_kernel/routing_capture.py new file mode 100644 index 0000000000..d0fa822058 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/routing_capture.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _scatter_routing_topk_to_cpu( + topk_ids, + mem_indexes, + routing_buffer_ptr, + total_size, + moe_layer_index: tl.constexpr, + layer_topk_size: tl.constexpr, + topk: tl.constexpr, + dtype_id: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < total_size + + token_offsets = offsets // topk + topk_offsets = offsets - token_offsets * topk + mem_index = tl.load(mem_indexes + token_offsets, mask=mask, other=-1).to(tl.int64) + data = tl.load(topk_ids + offsets, mask=mask, other=0) + + dst_offsets = mem_index * layer_topk_size + moe_layer_index * topk + topk_offsets + if dtype_id == 1: + dst_ptr = tl.load(routing_buffer_ptr).to(tl.pointer_type(tl.uint8)) + tl.store(dst_ptr + dst_offsets, data.to(tl.uint8), mask=mask) + else: + dst_ptr = 
tl.load(routing_buffer_ptr).to(tl.pointer_type(tl.int16)) + tl.store(dst_ptr + dst_offsets, data.to(tl.int16), mask=mask) + + +def scatter_routing_topk_to_cpu( + topk_ids: torch.Tensor, + mem_indexes: torch.Tensor, + routing_buffer_ptr: torch.Tensor, + moe_layer_index: int, + num_moe_layers: int, + topk: int, + dtype_id: int, +): + assert topk_ids.is_cuda + assert mem_indexes.is_cuda + assert mem_indexes.is_contiguous() + assert routing_buffer_ptr.is_cuda + assert routing_buffer_ptr.dtype == torch.uint64 + assert routing_buffer_ptr.numel() == 1 + assert topk_ids.dim() == 2 + assert topk_ids.shape[1] == topk + assert topk_ids.is_contiguous() + assert 0 <= moe_layer_index < num_moe_layers + + num_tokens = topk_ids.shape[0] + layer_topk_size = num_moe_layers * topk + total_size = num_tokens * topk + if total_size == 0: + return + + BLOCK = 1024 + grid = (triton.cdiv(total_size, BLOCK),) + _scatter_routing_topk_to_cpu[grid]( + topk_ids=topk_ids, + mem_indexes=mem_indexes, + routing_buffer_ptr=routing_buffer_ptr, + total_size=total_size, + moe_layer_index=moe_layer_index, + layer_topk_size=layer_topk_size, + topk=topk, + dtype_id=dtype_id, + BLOCK=BLOCK, + ) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..c75c871c72 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..14026090e6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 0000000000..939c939523 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 2, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + 
"BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..13ba4ba8e5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 
256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ee316f610b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + 
"NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e027701092 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + 
"NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ddda23d257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": 
false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..560ca6c09d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + 
"NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0713de7996 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 
16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e950ff0954 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + 
"GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7f479b8382 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + 
"num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b3051c6584 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "8448": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + 
"NUM_STAGE": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..fdb3212216 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a94e669353 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..441421fd5d --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..864d1d3f18 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 4096, + 
"num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bcf56e01f7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 128, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ba1dc8a75d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "3072": { + "BLOCK_SIZE": 2048, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..6f109e1c6e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 32768, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..198a196dfb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..537c7a90eb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + 
"BLOCK_SIZE": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..9a6dcb6fbf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4096": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..df501847ec --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..4cc6453d12 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -11,7 +11,7 @@ from frozendict import frozendict from lightllm.utils.device_utils import get_current_device_name from lightllm.utils.log_utils import init_logger -from typing 
import Callable, Optional, Union, List +from typing import Callable, List from lightllm.utils.envs_utils import get_triton_autotune_level from lightllm.common.kernel_config import KernelConfigs from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node @@ -106,14 +106,6 @@ def __init__( self.configs_gen_func = configs_gen_func self.kernel_name = kernel_name - self.cache_dir = os.path.join( - Path(__file__).parent, - "autotune_kernel_configs", - get_triton_version(), - get_current_device_name(), - self.kernel_name, - ) - os.makedirs(self.cache_dir, exist_ok=True) self.fn = fn self.static_key_func = static_key_func self.run_key_func = run_key_func @@ -209,6 +201,25 @@ def __call__(self, *args, **kwargs): return self.fn(*args, **kwargs) + @property + def cache_dir(self) -> str: + if not hasattr(self, "_cache_dir"): + device_name = get_current_device_name() + if device_name is None: + raise RuntimeError( + f"Autotuner for kernel {self.kernel_name} requires a visible CUDA/MUSA device " + f"to resolve its cache directory, but torch.cuda.is_available() is False." 
+ ) + self._cache_dir = os.path.join( + Path(__file__).parent, + "autotune_kernel_configs", + get_triton_version(), + device_name, + self.kernel_name, + ) + os.makedirs(self._cache_dir, exist_ok=True) + return self._cache_dir + def _try_load_cache(self, static_key): if static_key in self.cached_configs: return False diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index be819c94a0..dae79cc8a6 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -232,6 +232,7 @@ def _moe_ffn_tp( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + infer_state=infer_state, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -259,6 +260,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + infer_state=infer_state, ) if self.n_shared_experts is not None: diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 015b526fbc..2f0c01dbf6 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -300,6 +300,7 @@ def _ffn_moe(self, input, router_logits, infer_state: InferStateInfo, layer_weig topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + infer_state=infer_state, ) moe_out = self._tpsp_reduce(input=moe_out, infer_state=infer_state) return moe_out diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index b27ea8fd2d..490d2dc4c5 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ 
-52,6 +52,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + infer_state=infer_state, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) return self._tpsp_reduce(input=hidden_states, infer_state=infer_state) diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1de..0000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. 
- - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 0cf651598a..8134dc266d 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -21,25 +18,14 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + infer_state=infer_state, ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - ffn2_out = fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, - ) - return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) + return hidden_states.view(num_tokens, hidden_dim) diff --git 
a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 237c4ad897..c94135573b 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -12,6 +12,7 @@ from .vision_process import smart_resize from lightllm.models.qwen2.model import Qwen2TpPartModel import os +from typing import Union, List # Warp of the origal tokenizer class QWen2VLTokenizer(BaseMultiModalTokenizer): @@ -52,9 +53,13 @@ def get_image_token_length(self, img: ImageItem): def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - - origin_ids = self.tokenizer.encode(prompt) + def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs): + if isinstance(prompt, str): + origin_ids = self.tokenizer.encode(prompt) + elif isinstance(prompt, list): + origin_ids = prompt + else: + raise ValueError(f"Unsupported prompt type: {type(prompt)}") # -> origin_ids = [token for token in origin_ids if token != self.image_token_id] diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 8879aa2d27..7edfd5a6f9 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -86,6 +86,7 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + infer_state=infer_state, ) return hidden_states.view(num_tokens, hidden_dim) @@ -105,6 +106,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + infer_state=infer_state, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git 
a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index e4d80e6ff9..60bf0e6b76 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -135,6 +135,7 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + infer_state=infer_state, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) hidden_states.add_(shared_expert_out) @@ -156,6 +157,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + infer_state=infer_state, ) ep_output = ep_output.view(token_num, hidden_dim) ep_output.add_(shared_expert_out) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..1ef6375c3e 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -1,8 +1,7 @@ import argparse -def make_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() +def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( "--run_mode", @@ -80,6 +79,12 @@ def make_argument_parser() -> argparse.ArgumentParser: when a llm infer node start to set this params, the visual infer module will start a proxy module use config server to find remote vit infer nodes to infer img""", ) + parser.add_argument( + "--rl_rpyc_port", + type=int, + default=None, + help="The router RL control RPyC port. If unset, LightLLM will allocate one automatically.", + ) parser.add_argument( "--pd_kv_page_num", type=int, @@ -256,6 +261,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch" ) + parser.add_argument( + "--lightllm_instance_id", + type=int, + default=0, + help="Instance ID (0~7) for multi-instance port isolation. 
Each ID maps to a dedicated port range.", + ) parser.add_argument( "--use_config_server_to_init_nccl", action="store_true", @@ -769,6 +780,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used.""" ) + parser.add_argument( + "--enable_torch_memory_saver", + action="store_true", + help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""", + ) + parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""") parser.add_argument( "--disk_cache_dir", type=str, @@ -842,6 +859,12 @@ def make_argument_parser() -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) parser.add_argument( "--enable_profiling", type=str, @@ -858,3 +881,7 @@ def make_argument_parser() -> argparse.ArgumentParser: A NVTX range named 'LIGHTLLM_PROFILE' will be added within the profiling range.""", ) return parser + + +def make_argument_parser() -> argparse.ArgumentParser: + return add_cli_args(argparse.ArgumentParser()) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 270e2a8cfd..3393fec4a9 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -34,7 +34,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -50,6 +50,7 @@ from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from 
lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -61,6 +62,16 @@ ModelCard, ModelListResponse, ) +from .io_struct import ( + AbortReq, + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + UpdateWeightsFromIPCReq, + GeneralModelToHttpRpcRsp, +) from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -188,6 +199,22 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -411,6 +438,91 @@ async def metrics() -> Response: return response +@app.post("/abort_request") +async def abort_request(request: AbortReq, raw_request: Request): + """Abort a request.""" + try: + success, msg = await g_objs.httpserver_manager.abort_request(request) + if not success: + return create_error_response(HTTPStatus.REQUEST_TIMEOUT, msg, err_type="AbortRequestTimeout") + return Response(status_code=200) + except Exception as e: + logger.error("abort_request error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +async def handle_request_common(request_obj, handler): + try: + ret: GeneralModelToHttpRpcRsp = await 
handler(request_obj) + if ret.success: + return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200) + else: + return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) + except Exception as e: + logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +@app.post("/init_weights_update_group") +async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): + """Init weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + + +@app.post("/destroy_weights_update_group") +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): + """Destroy weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + + +@app.post("/update_weights_from_tensor") +async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + + +@app.post("/update_weights_from_ipc") +async def update_weights_from_ipc(request: UpdateWeightsFromIPCReq, raw_request: Request): + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_ipc) + + +@app.post("/flush_cache") +@app.get("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + return await handle_request_common(FlushCacheReq(), 
g_objs.httpserver_manager.flush_cache) + + +@app.post("/pause_generation") +async def pause_generation(): + await g_objs.httpserver_manager.pause_generation() + return Response(content="Generation paused successfully.", status_code=200) + + +@app.post("/continue_generation") +async def continue_generation(): + await g_objs.httpserver_manager.continue_generation() + return Response(content="Generation continued successfully.", status_code=200) + + +@app.get("/release_memory_occupation") +@app.post("/release_memory_occupation") +async def release_memory_occupation(request: ReleaseMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation) + + +@app.get("/resume_memory_occupation") +@app.post("/resume_memory_occupation") +async def resume_memory_occupation(request: ResumeMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 39a5808aab..28d57ccdc4 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] return_details = sample_params_dict.pop("return_details", False) + return_routed_experts = sample_params_dict.pop( + "return_routed_experts", httpserver_manager.args.enable_return_routed_experts + ) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() @@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, 
request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if return_routed_experts and routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) @@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] _ = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_routed_experts", None) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 6e04d5d47e..5306ecb698 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,11 +1,22 @@ import torch -from .api_cli import make_argument_parser +from .api_cli import add_cli_args +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to 
subprocess - parser = make_argument_parser() - args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start + +def launch_server(args: StartArgs): + from .api_start import pd_master_start, normal_or_p_d_start, config_server_start, visual_only_start + + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e if args.run_mode == "pd_master": pd_master_start(args) @@ -15,3 +26,13 @@ visual_only_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + add_cli_args(parser) + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3cf431d650..42d32c7ec3 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import sys import time @@ -18,6 +19,7 @@ from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.utils.config_utils import ( has_audio_module, has_vision_module, @@ -60,9 +62,31 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - 
start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") if http_server_process: @@ -70,10 +94,13 @@ def signal_handler(sig, frame): return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs +def _set_envs_and_config(args: StartArgs): + mp.set_start_method("spawn", force=True) - args: StartArgs = args + +def _launch_subprocesses(args: StartArgs): + + _set_envs_and_config(args) auto_set_max_req_total_len(args) set_unique_server_name(args) @@ -143,12 +170,6 @@ def normal_or_p_d_start(args): check_recommended_shm_size(args) assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] - # 确保单机上多实列不冲突 - if args.zmq_mode == "ipc:///tmp/": - zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" - args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 - args.zmq_mode = zmq_mode - logger.info(f"zmq mode head: {args.zmq_mode}") logger.info(f"use tgi api: {args.use_tgi_api}") @@ -210,12 +231,16 @@ def normal_or_p_d_start(args): # mtp params check if args.mtp_mode is not None: - assert args.mtp_draft_model_dir is not None + if args.mtp_draft_model_dir is None: + args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step assert args.mtp_step > 0 else: assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # automatically set visual_dp based on visual_tp and tp + if args.visual_tp < args.tp and args.tp % args.visual_tp == 0: + 
args.visual_dp = args.tp // args.visual_tp if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) @@ -336,8 +361,12 @@ def normal_or_p_d_start(args): assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] already_uesd_ports = [args.port] - if args.nccl_port is not None: + # nccl_port 只在 rank 0 上 bind(TCPStore listener),其他 rank 是 connect, + # 不应该把它加入端口锁定列表,否则单机多节点 tp 测试会冲突。 + if args.nccl_port is not None and args.node_rank == 0: already_uesd_ports.append(args.nccl_port) + if args.rl_rpyc_port is not None: + already_uesd_ports.append(args.rl_rpyc_port) if args.visual_nccl_ports is not None: already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) if not args.disable_audio and args.audio_nccl_ports is not None: @@ -350,9 +379,12 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + num=11 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + instance_id=args.lightllm_instance_id, used_ports=already_uesd_ports, ) + auto_ports_locker = PortLocker(can_use_ports) + auto_ports_locker.lock_port() logger.info(f"alloced ports: {can_use_ports}") ( nccl_port, @@ -365,8 +397,9 @@ def normal_or_p_d_start(args): cache_port, metric_port, multi_level_kv_cache_port, - ) = can_use_ports[0:10] - can_use_ports = can_use_ports[10:] + rl_rpyc_port, + ) = can_use_ports[0:11] + can_use_ports = can_use_ports[11:] if args.visual_nccl_ports is None: args.visual_nccl_ports = can_use_ports[: args.visual_dp] @@ -383,6 +416,18 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 if args.nccl_port is None: args.nccl_port = nccl_port + if args.rl_rpyc_port is None: + args.rl_rpyc_port = rl_rpyc_port + + set_unique_server_name(args) + + # 确保单机上多实列不冲突 + if args.zmq_mode == 
"ipc:///tmp/": + zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" + args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 + args.zmq_mode = zmq_mode + logger.info(f"zmq mode head: {args.zmq_mode}") + args.router_port = router_port args.router_profiler_port = router_profiler_port args.detokenization_port = detokenization_port @@ -415,6 +460,7 @@ def normal_or_p_d_start(args): logger.info(f"all start args:{args}") ports_locker.release_port() + auto_ports_locker.release_port() if args.enable_multimodal: process_manager.start_submodule_processes( @@ -486,6 +532,13 @@ def normal_or_p_d_start(args): ], ) + return process_manager + + +def normal_or_p_d_start(args: StartArgs): + + process_manager = _launch_subprocesses(args) + # 启动 Hypercorn command = [ "hypercorn", @@ -521,7 +574,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -540,16 +593,16 @@ def pd_master_start(args): logger.info(f"all start args:{args}") can_use_ports = alloc_can_use_network_port( - num=1, - used_ports=[ - args.port, - ], + num=1, used_ports=[args.nccl_port, args.port], instance_id=args.lightllm_instance_id ) + auto_ports_locker = PortLocker(can_use_ports) + auto_ports_locker.lock_port() metric_port = can_use_ports[0] args.metric_port = metric_port set_env_start_args(args) + auto_ports_locker.release_port() process_manager.start_submodule_processes( start_funcs=[ @@ -599,7 +652,10 @@ def visual_only_start(args): can_use_ports = alloc_can_use_network_port( num=5 + args.visual_dp * args.visual_tp + args.visual_dp, used_ports=already_uesd_ports, + instance_id=args.lightllm_instance_id, ) + auto_ports_locker = PortLocker(can_use_ports) + auto_ports_locker.lock_port() if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -620,6 +676,7 @@ def visual_only_start(args): logger.info(f"all start args:{args}") 
set_env_start_args(args) + auto_ports_locker.release_port() from .visualserver.visual_only_manager import start_visual_process diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index cbc63c898d..2514d9dacb 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -111,13 +111,18 @@ def __init__( def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) cls._stop_sequences = generation_cfg.get("stop", None) except: pass diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7f2b697091..90fede4cf4 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -3,6 +3,7 @@ import ctypes import numpy as np import time +from multiprocessing import shared_memory from .sampling_params import SamplingParams from .out_token_circlequeue import CircularQueue from .shm_array import ShmArray @@ -14,6 +15,7 @@ from lightllm.utils.kv_cache_utils import compute_token_list_hash 
from typing import List, Any, Union from lightllm.utils.log_utils import init_logger +from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -25,19 +27,20 @@ class FinishStatus(ctypes.Structure): NO_FINISH = 0 FINISHED_STOP = 1 FINISHED_LENGTH = 2 + FINISHED_ABORTED = 3 def __init__(self, init_state=NO_FINISH): self.status = init_state def set_status(self, new_status): - assert 0 <= new_status <= 2 + assert 0 <= new_status <= 3 self.status = new_status def get_status(self): return self.status def is_finished(self): - return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH + return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED def is_stopped(self): return self.status == self.FINISHED_STOP @@ -50,6 +53,8 @@ def get_finish_reason(self): return "stop" elif self.status == self.FINISHED_LENGTH: return "length" + elif self.status == self.FINISHED_ABORTED: + return "abort" return None @@ -277,6 +282,43 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def get_routing_data_shm_name(self): + service_uni_name = get_unique_server_name() + return f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.uint8): + """Create routing SHM at actual size.""" + self.shm_routing_data = ShmArray( + self.get_routing_data_shm_name(), (num_tokens, num_moe_layers, topk), dtype=np_dtype + ) + self.shm_routing_data.create_shm() + return + + def link_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.uint8): + """Link routing SHM at actual size.""" + if num_moe_layers <= 0 or num_tokens <= 0 or topk <= 0: + return None + shm_routing_data = ShmArray( + self.get_routing_data_shm_name(), (num_tokens, num_moe_layers, topk), dtype=np_dtype + ) + shm_routing_data.link_shm() + self.shm_routing_data = shm_routing_data + return self.shm_routing_data.arr + + def 
close_routing_data_shm_array(self): + """Close and unlink routing SHM (on-demand, no longer pooled).""" + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.close_shm() + self.shm_routing_data = None + else: + try: + shm = shared_memory.SharedMemory(name=self.get_routing_data_shm_name()) + shm.close() + shm.unlink() + except FileNotFoundError: + pass + return + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() @@ -297,9 +339,8 @@ def can_release(self): ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: - return True - + # if self.is_aborted and can_released_mark and ref_count_ok: + # return True ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty(): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c39559f5f6..9cb52f02b7 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -396,15 +396,18 @@ def init(self, tokenizer, **kwargs): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - if cls._repetition_penalty is None: - cls._repetition_penalty = 1.0 - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + 
+ cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) except: pass diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index c5ad512c6b..74d64b6c5e 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -26,6 +26,13 @@ def link_shm(self): self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) return + def detach_shm(self): + """Close handle without unlinking (SHM persists for reuse).""" + if self.shm is not None: + self.shm.close() + self.shm = None + self.arr = None + def close_shm(self): if self.shm is not None: self.shm.close() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 40c8028158..f6eacbfc8b 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass @@ -10,31 +10,50 @@ class StartArgs: default="normal", metadata={"choices": ["normal", "pd_master", "prefill", "decode", "config_server", "visual_only"]}, ) + performance_mode: str = field(default=None, metadata={"choices": ["personal"]}) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: 
int = field(default=None) config_server_visual_redis_port: int = field(default=None) afs_image_embed_dir: str = field(default=None) afs_embed_capacity: int = field(default=250000) - select_p_d_node_strategy: str = field(default=None) + rl_rpyc_port: int = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") + model_owner: Optional[str] = field(default=None) model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.8) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( default=None, - metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]}, + metadata={ + "choices": [ + "qwen25", + "llama3", + "mistral", + "deepseekv3", + "qwen", + "deepseekv31", + "deepseekv32", + "glm47", + "kimi_k2", + "qwen3_coder", + ] + }, ) reasoning_parser: Optional[str] = field( default=None, @@ -53,11 +72,12 @@ class StartArgs: "step3", "nano_v3", "interns1", + "gemma4", ] }, ) chat_template: Optional[str] = field(default=None) - running_max_req_size: int = field(default=512) + running_max_req_size: int = field(default=256) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) @@ -66,12 +86,13 @@ class StartArgs: max_req_total_len: Optional[int] = field(default=None) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=None) + lightllm_instance_id: int = field(default=0) use_config_server_to_init_nccl: bool = field(default=False) trust_remote_code: bool = field(default=False) detail_log: 
bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) - router_token_ratio: float = field(default=0.0) + router_token_ratio: float = field(default=None) router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) enable_prefill_decode_mixed: bool = field(default=False) @@ -80,7 +101,7 @@ class StartArgs: disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) disable_vision: Optional[bool] = field(default=None) @@ -109,12 +130,12 @@ class StartArgs: ) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") - grouping_key: List[str] = field(default_factory=list) + grouping_key: List[str] = field(default_factory=lambda: []) push_interval: int = field(default=10) visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) - visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) + visual_gpu_ids: List[int] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) @@ -132,19 +153,19 @@ class StartArgs: graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) - quant_type: Optional[str] = field(default=None) + quant_type: Optional[str] = field(default="none") quant_cfg: Optional[str] = field(default=None) - expert_dtype: Optional[str] = field(default=None, 
metadata={"choices": ["fp8", "fp4"]}) - vit_quant_type: Optional[str] = field(default=None) + vit_quant_type: Optional[str] = field(default="none") vit_quant_cfg: Optional[str] = field(default=None) + expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]}) llm_prefill_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) llm_decode_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) vit_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} ) llm_kv_type: str = field( default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} @@ -165,8 +186,6 @@ class StartArgs: "eagle_with_att", "vanilla_no_att", "eagle_no_att", - "qwen3next_vanilla", - "qwen3next_eagle", None, ] }, @@ -179,7 +198,7 @@ class StartArgs: pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) - cpu_cache_token_page_size: int = field(default=64) + cpu_cache_token_page_size: int = field(default=256) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) @@ -195,6 +214,27 @@ class StartArgs: metric_port: int = field(default=None) multinode_httpmanager_port: int = field(default=12345) multi_level_kv_cache_port: int = field(default=None) + # multi_modal + enable_multimodal_audio: bool = field(default=False) + + disable_shm_warning: bool = field(default=False) + dp_balancer: 
str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) + enable_custom_allgather: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) + enable_mps: bool = field(default=False) + multinode_router_gloo_port: int = field(default=20001) + schedule_time_interval: float = field(default=0.03) + use_dynamic_prompt_cache: bool = field(default=False) + disable_custom_allreduce: bool = field(default=False) + enable_torch_memory_saver: bool = field(default=False) + enable_weight_cpu_backup: bool = field(default=False) + hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]}) + enable_torch_fallback: bool = field(default=False) + enable_triton_fallback: bool = field(default=False) + + enable_return_routed_experts: bool = field(default=False) + + weight_version: str = "default" # hybrid attention model (Qwen3Next) linear_att_hash_page_size: int = field(default=512) diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9aa3a8effc..c77379986c 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool: return False def need_detoken(self): - if ( - (not self.req.is_aborted) - and (not self.req.stop_str_matched) - and len(self.output_ids) < self.req.candetoken_out_len - ): + if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len: return True return False @@ -83,8 +79,6 @@ def get_decode_tokens(self): return prefix_tokens, read_tokens def can_set_release_mark(self): - if self.req.is_aborted: - return True if self.req.stop_str_matched: return True if ( diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0f1b873111..2b1c0a0afa 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -3,13 +3,14 @@ import 
zmq.asyncio import asyncio import uvloop -import rpyc import socket +import rpyc import time import copy import hashlib import datetime import pickle +import base64 from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -29,8 +30,22 @@ from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_routing_config_from_model_dir, routing_dtype_id_to_np from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + AbortReq, + FlushCacheReq, + ReleaseMemoryReq, + ResumeMemoryReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + UpdateWeightsFromIPCReq, + GeneralModelToHttpRpcRsp, +) +from .rl_controller import HttpRlController from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name @@ -56,7 +71,6 @@ def __init__( self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0)) self._run_reqs_count_lock = AsyncLock(self._shm_lock_pool.get_lock_context(1)) self.node_rank = args.node_rank - self.disable_abort = args.nnodes > 1 and args.dp == 1 # mulitnode dp=1 mode, disable abort self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 self.is_multinode_tp_master = args.dp == 1 and args.nnodes > 1 and args.node_rank == 0 self.is_multinode_tp_slave = args.dp == 1 and args.nnodes > 1 and args.node_rank > 0 @@ -75,7 +89,7 @@ def __init__( self.multinode_req_manager = context.socket(zmq.PULL) self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") logger.info( - f"HttpServerManager listening for child 
node requests on *:{args.multinode_httpmanager_port}" + f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}" ) self.enable_multimodal = args.enable_multimodal @@ -122,6 +136,13 @@ def __init__( # Timemark of the latest successful inference, used by passive /health checks. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + self._routing_config = ( + get_routing_config_from_model_dir(args.model_dir) + if args.enable_return_routed_experts and args.node_rank == 0 + else None + ) + + self.rl_controller = HttpRlController(self) self.run_reqs_count_mark = SharedInt(f"{get_unique_server_name()}_run_reqs_count_mark") self.run_reqs_count_mark.set_value(0) @@ -335,6 +356,11 @@ async def generate( self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() + 1) try: + if await self.rl_controller.wait_until_generation_allowed(group_request_id): + for output in self._build_aborted_generation_outputs(group_request_id, sampling_params): + yield output + return + original_multimodal_params = None if self.is_multinode_tp_master: original_multimodal_params = copy.deepcopy(multimodal_params) @@ -349,6 +375,7 @@ async def generate( # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) + # encode prompt_ids = await self._encode(prompt, multimodal_params, sampling_params) self._log_stage_timing( @@ -486,6 +513,15 @@ async def generate( self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() - 1) return + def _build_aborted_generation_outputs(self, group_request_id: int, sampling_params: SamplingParams): + # Paused requests never enter router queues, so synthesize an aborted empty + # result that both stream and non-stream APIs can consume normally. 
+ finish_status = FinishStatus() + finish_status.set_status(FinishStatus.FINISHED_ABORTED) + metadata = {"prompt_tokens": 0, "count_output_tokens": 0} + for i in range(sampling_params.n): + yield group_request_id + i, "", metadata, finish_status + def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: image_tokens = 0 audio_tokens = 0 @@ -500,6 +536,43 @@ def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple return image_tokens, audio_tokens + def _read_routed_experts_metadata(self, req: Req): + if self._routing_config is None: + return None + num_moe_layers, topk, dtype_id = self._routing_config + num_tokens = req.input_len + req.shm_cur_output_len + if num_tokens <= 0: + return None + + try: + routing_data = req.link_routing_data_shm_array( + num_moe_layers, num_tokens, topk, np_dtype=routing_dtype_id_to_np(dtype_id) + ) + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {req.request_id}: {e}") + return None + + async def _wait_and_read_routed_experts_metadata(self, req: Req, timeout: float = 60.0): + if not await self.rl_controller.wait_until_can_released_mark(req, timeout=timeout): + return None + + start_time = time.time() + while True: + routing_meta = self._read_routed_experts_metadata(req) + if routing_meta is not None: + return routing_meta + if time.time() - start_time > timeout: + logger.warning(f"wait routing data shm timeout, req_id={req.request_id}, timeout={timeout}s") + return None + await asyncio.sleep(0.005) + async def _log_req_header(self, request_headers, group_request_id: int): x_request_id = request_headers.get("X-Request-Id", "") @@ -562,6 +635,18 @@ async def _encode( else: raise ValueError("prompt List[int] format contain id > vocab_size") else: + if 
self.enable_multimodal and self.pd_mode.is_P_or_NORMAL(): + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" + if multimodal_params.audios: + assert not self.args.disable_audio, "audio multimodal not enabled" + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + return self.tokenizer.encode( + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, + ) return prompt else: raise ValueError(f"prompt format error, get type{type(prompt)}") @@ -688,7 +773,7 @@ async def _wait_to_token_package( if req_status.aborted: raise Exception(f"req_id {group_request_id} aborted notifyed by other module") - if not self.disable_abort and request is not None and await request.is_disconnected(): + if request is not None and await request.is_disconnected(): await self.abort(group_request_id) raise ClientDisconnected( group_request_id=group_request_id, reason="_wait_to_token_package check network disconnected" @@ -800,6 +885,9 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def abort_request(self, request: AbortReq) -> Tuple[bool, str]: + return await self.rl_controller.abort_request(request) + def _get_router_profiler_client(self): router_profiler_client = getattr(self, "router_profiler_client", None) if router_profiler_client is None or getattr(router_profiler_client, "closed", False): @@ -839,6 +927,10 @@ async def recycle_resource_loop(self): self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) _is_aborted = False for req in req_status.group_req_objs.shm_req_objs: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") _is_aborted = _is_aborted or req.is_aborted logger.debug(f"httpserver release req_id {req.request_id}, index 
{req.index_in_shm_mem}") await self.shm_req_manager.async_put_back_req_obj(req) @@ -931,6 +1023,11 @@ async def handle_loop(self): else: finish_status = FinishStatus(req.finish_status.status) + if self._routing_config is not None: + routing_meta = await self._wait_and_read_routed_experts_metadata(req) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta + token_list.append((req_id, text, metadata, finish_status)) else: break @@ -945,6 +1042,36 @@ async def handle_loop(self): self.recycle_event.set() return + async def pause_generation(self): + return await self.rl_controller.pause_generation() + + async def continue_generation(self): + return await self.rl_controller.continue_generation() + + async def flush_cache(self, request: FlushCacheReq): + return await self.rl_controller.flush_cache(request) + + async def release_memory_occupation(self, request: ReleaseMemoryReq): + return await self.rl_controller.release_memory_occupation(request) + + async def resume_memory_occupation(self, request: ResumeMemoryReq): + return await self.rl_controller.resume_memory_occupation(request) + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.rl_controller.init_weights_update_group(request) + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.rl_controller.destroy_weights_update_group(request) + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + return await self.rl_controller.update_weights_from_distributed(request) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> GeneralModelToHttpRpcRsp: + return await self.rl_controller.update_weights_from_tensor(request) + + async def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq) -> GeneralModelToHttpRpcRsp: + return await self.rl_controller.update_weights_from_ipc(request) + class ReqStatus: def __init__(self, 
group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: diff --git a/lightllm/server/httpserver/rl_controller.py b/lightllm/server/httpserver/rl_controller.py new file mode 100644 index 0000000000..e40ccc916c --- /dev/null +++ b/lightllm/server/httpserver/rl_controller.py @@ -0,0 +1,282 @@ +import asyncio +import socket +import time +from contextlib import asynccontextmanager +from typing import Optional, Tuple + +import rpyc +from rpyc.utils.classic import obtain + +from lightllm.server.io_struct import ( + AbortReq, + FlushCacheReq, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + ReleaseMemoryReq, + ResumeMemoryReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromIPCReq, + UpdateWeightsFromTensorReq, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class _GenerationPauseGate: + """Generation pause gate. + + Pending requests are tracked only while they are waiting at the pause gate. + Once generation is allowed, the request leaves this lightweight path and + continues as a normal request managed by req_id_to_out_inf. 
+ """ + + _RUNNING = 0 + _PAUSED = 1 + + def __init__(self) -> None: + self._state = self._RUNNING + self._pending_request_abort_events = {} + self._cond = asyncio.Condition() + + @asynccontextmanager + async def pause_and_abort_context(self): + async with self._cond: + if self._state != self._RUNNING: + do_abort = False + else: + self._state = self._PAUSED + do_abort = True + yield do_abort + + async def register_pending_request(self, request_id: int): + async with self._cond: + self._pending_request_abort_events[request_id] = asyncio.Event() + + async def unregister_pending_request(self, request_id: int): + async with self._cond: + self._pending_request_abort_events.pop(request_id, None) + + async def wait_until_running_or_aborted(self, request_id: int) -> bool: + """Returns True when the pending request should finish as aborted.""" + async with self._cond: + abort_event = self._pending_request_abort_events.get(request_id) + if abort_event is None: + return False + if abort_event.is_set(): + return True + if self._state == self._RUNNING: + return False + + resume_task = asyncio.create_task(self._wait_until_running()) + abort_task = asyncio.create_task(abort_event.wait()) + done, pending = await asyncio.wait({resume_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + if pending: + await asyncio.gather(*pending, return_exceptions=True) + return abort_task in done and abort_event.is_set() + + async def abort_pending_request(self, request_id: int) -> bool: + async with self._cond: + abort_event = self._pending_request_abort_events.get(request_id) + if abort_event is None: + return False + abort_event.set() + return True + + async def abort_all_pending_requests(self): + async with self._cond: + for abort_event in self._pending_request_abort_events.values(): + abort_event.set() + + async def is_pending_request(self, request_id: int) -> bool: + async with self._cond: + return request_id in self._pending_request_abort_events + 
+ async def get_pending_request_count(self) -> int: + async with self._cond: + return len(self._pending_request_abort_events) + + async def _wait_until_running(self) -> None: + async with self._cond: + while self._state != self._RUNNING: + await self._cond.wait() + + async def resume(self) -> None: + async with self._cond: + self._state = self._RUNNING + self._cond.notify_all() + + +class HttpRlController: + def __init__(self, manager) -> None: + self.manager = manager + self.args = manager.args + self._generation_gate = _GenerationPauseGate() + + async def wait_until_generation_allowed(self, request_id: int) -> bool: + await self._generation_gate.register_pending_request(request_id) + try: + return await self._generation_gate.wait_until_running_or_aborted(request_id) + finally: + await self._generation_gate.unregister_pending_request(request_id) + + async def wait_until_can_released_mark(self, req, timeout: float = 60.0) -> bool: + start_time = time.time() + while not req.can_released_mark: + if time.time() - start_time > timeout: + logger.warning(f"wait req can_released_mark timeout, req_id={req.request_id}, timeout={timeout}s") + return False + await asyncio.sleep(0.005) + return True + + async def _wait_for_abort_released( + self, request_id: Optional[int], abort_all: bool, timeout: float = 60.0 + ) -> Tuple[bool, str]: + start_time = time.time() + empty_since = None + while True: + if abort_all: + pending_request_count = await self._generation_gate.get_pending_request_count() + if len(self.manager.req_id_to_out_inf) == 0 and pending_request_count == 0: + empty_since = empty_since or time.time() + if time.time() - empty_since >= 1.0: + return True, "" + else: + empty_since = None + await self._generation_gate.abort_all_pending_requests() + for group_req_id, req_status in list(self.manager.req_id_to_out_inf.items()): + if req_status is not None and any( + not req.is_aborted for req in req_status.group_req_objs.shm_req_objs + ): + await 
self.manager.abort(group_req_id) + else: + req_status = self.manager.req_id_to_out_inf.get(request_id, None) + if req_status is not None: + if any(not req.is_aborted for req in req_status.group_req_objs.shm_req_objs): + await self.manager.abort(request_id) + elif not await self._generation_gate.is_pending_request(request_id): + return True, "" + + if time.time() - start_time > timeout: + error_msg = ( + f"abort request wait release timeout, request_id={request_id}, abort_all={abort_all}, " + f"timeout={timeout}s" + ) + logger.error(error_msg) + return False, error_msg + + await asyncio.sleep(0.02) + return True, "" + + async def abort_request(self, request: AbortReq) -> Tuple[bool, str]: + request_id = request.request_id + if request.abort_all: + await self._generation_gate.abort_all_pending_requests() + for group_req_id in list(self.manager.req_id_to_out_inf.keys()): + await self.manager.abort(group_req_id) + return await self._wait_for_abort_released(request_id=None, abort_all=True) + + if request_id is None: + return True, "" + + await self._generation_gate.abort_pending_request(request_id) + await self.manager.abort(request_id) + return await self._wait_for_abort_released(request_id=request_id, abort_all=False) + + async def pause_generation(self): + async with self._generation_gate.pause_and_abort_context() as do_abort: + if not do_abort: + return + while True: + success, msg = await self.abort_request(AbortReq(request_id=None, abort_all=True)) + if success: + break + logger.warning(f"pause_generation abort_all still waiting: {msg}") + await asyncio.sleep(1.0) + + async def continue_generation(self): + await self._generation_gate.resume() + + def _call_router_rl_sync(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + from lightllm.utils.retry_utils import retry + + conn = retry(max_attempts=20, wait_time=0.5)(rpyc.connect)( + "localhost", + self.args.rl_rpyc_port, + config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + try: + 
conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return obtain(conn.root.rl_op(req)) + finally: + try: + conn.close() + except BaseException: + pass + + async def _call_router_rl(self, func_name: str, func_args=None) -> GeneralModelToHttpRpcRsp: + req = GeneralHttpToModelRpcReq(func_name=func_name, func_args=func_args) + try: + return await asyncio.to_thread(self._call_router_rl_sync, req) + except BaseException as e: + logger.exception(f"rl rpyc call {func_name} failed: {e}") + return GeneralModelToHttpRpcRsp( + success=False, + msg=f"rl rpyc call {func_name} error: {e}", + func_name=func_name, + ) + + async def flush_cache(self, request: FlushCacheReq): + return await self._call_router_rl("flush_cache", request) + + async def release_memory_occupation(self, request: ReleaseMemoryReq): + assert ( + len(self.manager.req_id_to_out_inf) == 0 + ), "there are still requests running, cannot release memory occupation" + return await self._call_router_rl("release_memory_occupation", request.tags) + + async def resume_memory_occupation(self, request: ResumeMemoryReq): + return await self._call_router_rl("resume_memory_occupation", request.tags) + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self._call_router_rl("init_weights_update_group", request) + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self._call_router_rl("destroy_weights_update_group", request) + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + if request.abort_all_requests: + success, msg = await self.abort_request(AbortReq(abort_all=True)) + if not success: + return GeneralModelToHttpRpcRsp(success=False, msg=msg, func_name="update_weights_from_distributed") + if request.flush_cache: + ret = await self.flush_cache(FlushCacheReq()) + if not ret.success: + return ret + return await 
self._call_router_rl("update_weights_from_distributed", request) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> GeneralModelToHttpRpcRsp: + if request.abort_all_requests: + success, msg = await self.abort_request(AbortReq(abort_all=True)) + if not success: + return GeneralModelToHttpRpcRsp(success=False, msg=msg, func_name="update_weights_from_tensor") + if request.flush_cache: + ret = await self.flush_cache(FlushCacheReq()) + if not ret.success: + return ret + return await self._call_router_rl("update_weights_from_tensor", request) + + async def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq) -> GeneralModelToHttpRpcRsp: + if request.abort_all_requests: + success, msg = await self.abort_request(AbortReq(abort_all=True)) + if not success: + return GeneralModelToHttpRpcRsp(success=False, msg=msg, func_name="update_weights_from_ipc") + if request.flush_cache: + ret = await self.flush_cache(FlushCacheReq()) + if not ret.success: + return ret + return await self._call_router_rl("update_weights_from_ipc", request) diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py new file mode 100644 index 0000000000..1ce42f25ed --- /dev/null +++ b/lightllm/server/io_struct.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass +from typing import List, Optional, Any, Union +from lightllm.utils.torch_memory_saver_utils import MemoryTag + + +@dataclass +class AbortReq: + # 外部调用传入,等同内部的 group_req_id + request_id: Optional[int] = None + abort_all: bool = False + + +def _normalize_memory_tags(tags): + if tags is None: + return None + return [tag if isinstance(tag, MemoryTag) else MemoryTag(tag) for tag in tags] + + +@dataclass +class FlushCacheReq: + pass + + +@dataclass +class ReleaseMemoryReq: + tags: Optional[List[MemoryTag]] = None + + def __post_init__(self): + self.tags = _normalize_memory_tags(self.tags) + + +@dataclass +class ResumeMemoryReq: + tags: Optional[List[MemoryTag]] = None + + def 
__post_init__(self): + self.tags = _normalize_memory_tags(self.tags) + + +@dataclass +class GeneralHttpToModelRpcReq: + func_name: str + func_args: Optional[Any] = None + + +@dataclass +class GeneralModelToHttpRpcRsp: + success: bool + msg: Optional[str] + func_name: str + func_rsp: Optional[Any] = None + + +@dataclass +class InitWeightsUpdateGroupReq: + master_address: str + master_port: int + rank_offset: int + world_size: int + group_name: str = "weight_update_group" + backend: str = "nccl" + + +@dataclass +class DestroyWeightsUpdateGroupReq: + group_name: str = "weight_update_group" + + +@dataclass +class UpdateWeightsFromDistributedReq: + names: List[str] + dtypes: List[str] + shapes: List[List[int]] + group_name: str = "weight_update_group" + flush_cache: bool = True + abort_all_requests: bool = False + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromTensorReq: + """Update model weights from tensor input. + + - Tensors are serialized for transmission + - Data is structured in JSON for easy transmission over HTTP + """ + + serialized_named_tensors: List[Union[str, bytes]] + load_format: Optional[str] = None + flush_cache: bool = True + abort_all_requests: bool = False + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromIPCReq: + ipc_handle: Optional[Union[str, dict]] = None + use_shm: bool = False + flush_cache: bool = True + abort_all_requests: bool = False + weight_version: Optional[str] = None diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 9541e434c8..784762717d 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -77,9 +77,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid diff --git a/lightllm/server/req_id_generator.py 
b/lightllm/server/req_id_generator.py index f7c099c292..971a3644cc 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -34,6 +34,9 @@ def __init__(self): logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): + if self.args.httpserver_workers == 1: + return + from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs.shm_array import ShmArray diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py index bf07e121e6..6a8e0a3917 100644 --- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py @@ -470,6 +470,10 @@ def clear_tree_nodes(self): self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) return + def flush_cache(self): + self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) + return + def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]: assert not node.is_big_page_node() iter_node = node diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 21e26c5854..c103a61473 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -106,6 +106,7 @@ class RadixCache: def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager + self.total_token_num = total_token_num self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -419,6 +420,10 @@ def clear_tree_nodes(self): self.refed_tokens_num.arr[0] = 0 return + def flush_cache(self): + self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) + return + def 
dec_node_ref_counter(self, node: TreeNode): if node is None: return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index dfb8866601..5401189056 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -12,7 +12,7 @@ import torch.multiprocessing as mp import torch.distributed as dist import multiprocessing -from typing import Dict, List, Optional +from typing import List from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue @@ -29,6 +29,10 @@ from lightllm.utils.profiler import ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp, +) from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -36,6 +40,7 @@ from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from .stats import RouterStatics from .profiler_service import RouterProfilerCmdQueue, start_router_profiler_server +from .rl_rpyc import RouterRlOpQueue, start_router_rl_rpyc_server logger = init_logger(__name__) @@ -104,6 +109,9 @@ def __init__(self, args: StartArgs): else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False) ) self.router_statics = RouterStatics(self.args) + + # RL 管理 rpyc 队列:rpyc 线程把请求放进来,asyncio 主循环每个 step 取出处理 + self.rl_op_queue = RouterRlOpQueue() self.profiler_cmd_queue = RouterProfilerCmdQueue() return @@ -349,11 +357,42 @@ def _filter_reqs_from_running_batch(self): self.running_batch = None return + def _sync_abort_req_ids_from_master(self, reqs: List[Req]): + local_aborted_req_ids = [req.request_id for req in reqs if req.is_aborted] + if not self.is_multinode_tp: + return 
set(local_aborted_req_ids) + + # 多节点 TP 下 abort 只以 rank 0 httpserver 看到的状态为准。 + if not self.is_multinode_tp_master: + local_aborted_req_ids = [] + + aborted_req_num = torch.tensor([len(local_aborted_req_ids)], dtype=torch.int64, device="cpu") + dist.broadcast(aborted_req_num, src=0, group=self.mulitnode_group) + aborted_req_num = int(aborted_req_num.item()) + if aborted_req_num == 0: + return set() + + if self.is_multinode_tp_master: + aborted_req_ids = torch.tensor(local_aborted_req_ids, dtype=torch.int64, device="cpu") + else: + aborted_req_ids = torch.empty(aborted_req_num, dtype=torch.int64, device="cpu") + dist.broadcast(aborted_req_ids, src=0, group=self.mulitnode_group) + return {int(req_id) for req_id in aborted_req_ids.tolist()} + + def _mark_reqs_aborted(self, reqs: List[Req], aborted_req_ids): + if not aborted_req_ids: + return + for req in reqs: + if req.request_id in aborted_req_ids: + req.is_aborted = True + return + def _get_aborted_reqs_from_running_batch(self) -> List[Req]: ans = [] - if self.running_batch is None: - return ans - for req in self.running_batch.reqs: + running_reqs = [] if self.running_batch is None else self.running_batch.reqs + aborted_req_ids = self._sync_abort_req_ids_from_master(running_reqs) + self._mark_reqs_aborted(running_reqs, aborted_req_ids) + for req in running_reqs: if req.is_aborted and req._router_aborted is False: req._router_aborted = True ans.append(req) @@ -453,11 +492,19 @@ def _multinode_tp_generate_new_batch(self): dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) req_id_select_mark = [1 for _ in range(len(req_ids))] req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") + # select_mark 仍需 MIN allreduce: slave 是否已经从 httpserver 收到 generate 请求 dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + aborted_req_ids = self._sync_abort_req_ids_from_master(new_batch.reqs) back_req_list = [] - for req_id, select in zip(req_ids, 
req_id_select_mark.numpy()): - if select == 0: - req = new_batch.pop_req(req_id) + for req_id, select in zip(req_ids, req_id_select_mark.tolist()): + is_aborted = req_id in aborted_req_ids + if select == 1 and not is_aborted: + continue + req = new_batch.pop_req(req_id) + if is_aborted and select == 1: + req.is_aborted = True + self.req_queue.release_aborted_req(req) + else: back_req_list.append(req) self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list if new_batch.is_clear(): @@ -471,25 +518,30 @@ def _multinode_tp_generate_new_batch(self): else: req_ids = [None for _ in range(req_num)] dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) - all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list} req_id_select_mark = [] for req_id in req_ids: - req_id_select_mark.append(1 if req_id in all_req_id_set else 0) + req_id_select_mark.append(1 if req_id in id_to_req_obj else 0) req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) - select_req_ids = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): - if select == 1: - select_req_ids.append(req_id) - + aborted_req_ids = self._sync_abort_req_ids_from_master([]) select_reqs = [] - for req_id in select_req_ids: - for req in self.req_queue.waiting_req_list: - if req.request_id == req_id: - select_reqs.append(req) - - for req in select_reqs: - self.req_queue.waiting_req_list.remove(req) + aborted_reqs = [] + for req_id, select in zip(req_ids, req_id_select_mark.tolist()): + if select == 1: + req = id_to_req_obj[req_id] + if req_id in aborted_req_ids: + req.is_aborted = True + aborted_reqs.append(req) + continue + select_reqs.append(req) + handled_req_ids = {req.request_id for req in select_reqs + aborted_reqs} + if handled_req_ids: + 
self.req_queue.waiting_req_list = [ + req for req in self.req_queue.waiting_req_list if req.request_id not in handled_req_ids + ] + for req in aborted_reqs: + self.req_queue.release_aborted_req(req) if select_reqs: new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node) else: @@ -514,7 +566,7 @@ async def _recv_new_reqs_and_schedule(self): if isinstance(recv_req, GroupReqIndexes): self._add_req(recv_req) else: - assert False, f"Error Req Inf {recv_req}" + raise ValueError(f"Unknown request type: {type(recv_req)}") # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -523,6 +575,8 @@ async def _recv_new_reqs_and_schedule(self): # 当队列已经开始清空的时候,将一次接受的数量下调 self.recv_max_count = 64 + await self._process_special_reqs() + if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: @@ -530,6 +584,66 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return + async def _process_special_reqs(self): + # master: 从 RL rpyc 队列里取出 (req, future) — slave 的队列恒为空(无 rpyc service) + pairs = self.rl_op_queue.pop_all() + + reqs: List[GeneralHttpToModelRpcReq] = [req for req, _ in pairs] + + # 多机 TP:master 通过 NCCL 广播 req 到 slave router;slave 在自己的主循环里到达此处时,会从 broadcast 收到 master 的 reqs + if self.is_multinode_tp: + reqs = self.broadcast_reqs_to_other_nodes(reqs) + + for i, req in enumerate(reqs): + assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq" + try: + ret = await self.forward_to_model(req) + except BaseException as e: + logger.exception(f"forward_to_model failed for {req.func_name}: {e}") + ret = GeneralModelToHttpRpcRsp( + success=False, msg=f"forward_to_model error: {e}", func_name=req.func_name + ) + # 只有 master 持有 future,slave 的 pairs 始终为空 + if i < len(pairs): + _, fut = pairs[i] + if not fut.done(): + fut.set_result(ret) + + def broadcast_reqs_to_other_nodes(self, reqs: List[GeneralHttpToModelRpcReq]): + req_num = len(reqs) + if 
self.node_rank == 0: + req_nums = [len(reqs)] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + else: + req_nums = [None] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + reqs = [None for _ in range(req_num)] + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + return reqs + + async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + forward_to_model_tasks = [] + for model_rpc_client in self.model_rpc_clients: + forward_to_model_tasks.append(model_rpc_client.forward_to_model(req)) + all_ret = await asyncio.gather(*forward_to_model_tasks) + ret: GeneralModelToHttpRpcRsp = next((res for res in all_ret if not res.success), all_ret[0]) + ret.success = all(res.success for res in all_ret) + if self.is_multinode_tp: + output_list = [None for _ in range(self.nnodes)] if self.node_rank == 0 else None + dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) + if self.node_rank == 0: + for res in output_list: + res: GeneralModelToHttpRpcRsp + if not res.success: + ret = res + break + return ret + def clean_up(self): return @@ -557,6 +671,10 @@ def handle_exception(loop, context): args, router.profiler_cmd_queue, ) + router.rl_rpyc_server, router.rl_rpyc_thread = start_router_rl_rpyc_server( + args, + router.rl_op_queue, + ) except: import traceback import sys diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5c2d0d45fb..2f79c441bb 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -23,6 +23,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import PDDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client 
import CpuEmbedCacheClient +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -121,6 +122,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + if not (req.shm_req.finish_status.is_finished() or req.shm_req.stop_str_matched): + return + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + mgr = _routing_mgr.g_routing_capture_manager + routing_data = mgr.extract_routing_data(mem_indexes) + req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype) + req.shm_req.shm_routing_data.arr[:] = routing_data + req.shm_req.shm_routing_data.detach_shm() + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -274,6 +285,8 @@ def _filter(self, finished_request_ids: List[int]): req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() + if _routing_mgr.g_routing_capture_manager is not None: + g_infer_context._extract_routing_data(req) self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) @@ -849,6 +862,8 @@ def update_finish_status(self, eos_ids, output_len: int): self.finish_status.set_status(FinishStatus.FINISHED_STOP) elif output_len >= self.sampling_param.shm_param.max_new_tokens: self.finish_status.set_status(FinishStatus.FINISHED_LENGTH) + elif self.infer_aborted: + self.finish_status.set_status(FinishStatus.FINISHED_ABORTED) return def _stop_sequences_matched(self, output_len: int): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a65dfb1bbb..25b5b7e1a4 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,7 +4,7 @@ import time import threading import torch.distributed as dist -from typing import List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger @@ -33,8 +33,8 @@ enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) +from lightllm.distributed import dist_group_manager from lightllm.distributed.communication_op import ( - dist_group_manager, all_gather_into_tensor, all_reduce, broadcast, @@ -597,7 +597,7 @@ def _get_classed_reqs( paused_reqs.append(req_obj) continue - if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + if req_obj.finish_status.is_finished(): if support_overlap: # 延迟处理 req_obj.filter_mark = True diff --git a/lightllm/server/router/model_infer/mode_backend/rl_backend_ops.py b/lightllm/server/router/model_infer/mode_backend/rl_backend_ops.py new file mode 100644 index 0000000000..aab832050f --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/rl_backend_ops.py @@ -0,0 +1,343 @@ +import gc +from typing import List, Optional + +import torch + +from lightllm.utils.dist_utils import init_custom_process_group +from lightllm.utils.rl.serialization import LocalSerializedTensor, MultiprocessingSerializer +from lightllm.utils.rl.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata +from lightllm.utils.rl.torch_cuda_ipc import cuda_rebuild_device_fallback, monkey_patch_torch_reductions +from lightllm.utils.torch_memory_saver_utils import MemoryTag +from lightllm.server.io_struct import ( + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromIPCReq, + UpdateWeightsFromTensorReq, +) + + 
+class RlBackendOps: + MEMORY_TAG_ORDER = (MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH) + + SUPPORTED = frozenset( + { + "flush_cache", + "release_memory_occupation", + "resume_memory_occupation", + "init_weights_update_group", + "destroy_weights_update_group", + "update_weights_from_distributed", + "update_weights_from_tensor", + "update_weights_from_ipc", + } + ) + + def __init__(self, backend) -> None: + self.backend = backend + self._model_update_group = {} + self._skip_tensor_updates_reason = None + self.logger = backend.logger + + @classmethod + def supports(cls, func_name: str) -> bool: + return func_name in cls.SUPPORTED + + def dispatch(self, func_name: str, func_args): + if not self.supports(func_name): + raise ValueError(f"RlBackendOps does not support function {func_name}") + return getattr(self, func_name)(func_args) + + def flush_cache(self, request: FlushCacheReq): + if self.backend.radix_cache is not None: + self.backend.radix_cache.flush_cache() + return True, "Succeeded to flush cache." + + def _iter_memory_tags(self, tags: Optional[List[MemoryTag]]): + return self.MEMORY_TAG_ORDER if tags is None else tags + + def _clear_cuda_cache(self): + torch.cuda.empty_cache() + gc.collect() + + def _pause_memory_tags(self, tags: Optional[List[MemoryTag]]): + torch.cuda.synchronize() + for tag in self._iter_memory_tags(tags): + self.backend.model.torch_memory_saver.pause(tag=tag) + self._clear_cuda_cache() + + def _resume_memory_tags(self, tags: Optional[List[MemoryTag]]): + self._clear_cuda_cache() + for tag in self._iter_memory_tags(tags): + self.backend.model.torch_memory_saver.resume(tag=tag) + + def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + try: + self._pause_memory_tags(tags) + self.flush_cache(request=None) + return True, "Succeeded to release memory occupation." 
+ except Exception as e: + self.logger.error(f"release memory occupation failed: {str(e)}") + return False, f"release memory occupation failed: {str(e)}" + + def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): + try: + self._resume_memory_tags(tags) + return True, "Succeeded to resume memory occupation." + except Exception as e: + self.logger.error(f"resume memory occupation failed: {str(e)}") + return False, f"resume memory occupation failed: {str(e)}" + + def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + assert torch.distributed.is_initialized(), "Default torch process group must be initialized" + + assert request.group_name != "", "Group name cannot be empty" + rank_offset = request.rank_offset + rank = rank_offset + self.backend.rank_in_dp + world_size = request.world_size + group_name = request.group_name + self.logger.info( + f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, " + f" backend={request.backend}" + ) + + try: + if group_name in self._model_update_group: + raise ValueError(f"Process group with name {group_name} already exists.") + + self._model_update_group[group_name] = init_custom_process_group( + backend=request.backend, + init_method=f"tcp://{request.master_address}:{request.master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + return True, "Succeeded to initialize custom process group." + + except Exception as e: + message = f"Failed to initialize custom process group: {e}." 
+ self.logger.error(message) + return False, message + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + try: + if request.group_name in self._model_update_group: + pg = self._model_update_group.pop(request.group_name) + torch.distributed.destroy_process_group(pg) + return True, "Succeeded to destroy custom process group." + else: + return False, "The group to be destroyed does not exist." + except Exception as e: + message = f"Failed to destroy custom process group: {e}." + self.logger.error(message) + return False, message + + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + """ + Update model weights online through the custom weight update process group. + """ + + assert request.group_name in self._model_update_group, ( + f"Group {request.group_name} not in {list(self._model_update_group.keys())}. " + "Please call `init_weights_update_group` first." + ) + + try: + weights = {} + handles = [] + for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device="cuda") + handles.append( + torch.distributed.broadcast( + weight, + src=0, + group=self._model_update_group[request.group_name], + async_op=True, + ) + ) + weights[name] = weight + for handle in handles: + handle.wait() + + self.backend.model.load_weights(weights) + return True, "Succeeded to update parameter online from distributed." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." 
+ ) + self.logger.error(error_msg) + return False, error_msg + + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + converted_metadata = [] + for meta in metadata: + if isinstance(meta, dict): + converted_meta = FlattenedTensorMetadata( + name=meta["name"], + shape=meta["shape"], + dtype=meta["dtype"], + start_idx=meta["start_idx"], + end_idx=meta["end_idx"], + numel=meta["numel"], + ) + else: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) + reconstructed_tensors = bucket.reconstruct_tensors() + + named_tensors = {name: tensor for name, tensor in reconstructed_tensors} + loaded, skipped = self._load_compatible_named_tensors(named_tensors) + + return ( + True, + "Succeeded to update parameter online from flattened bucket tensor. " + f"loaded={loaded}, skipped={skipped}.", + ) + + @staticmethod + def _iter_named_tensors(named_tensors): + if isinstance(named_tensors, dict): + return named_tensors.items() + return named_tensors + + def _get_tensor_update_skip_reason(self, items): + if self._skip_tensor_updates_reason is not None: + return self._skip_tensor_updates_reason + + target_config = getattr(self.backend.model, "config", {}) or {} + target_is_moe = bool(target_config.get("num_experts") or target_config.get("moe_intermediate_size")) + source_has_moe_experts = any(".mlp.experts." 
in name for name, _ in items) + if source_has_moe_experts and not target_is_moe: + self._skip_tensor_updates_reason = "received MoE expert weights for a non-MoE backend" + self.logger.warning("skip tensor weight updates: %s", self._skip_tensor_updates_reason) + return self._skip_tensor_updates_reason + + return None + + def _load_compatible_named_tensors(self, named_tensors): + items = list(self._iter_named_tensors(named_tensors)) + skip_reason = self._get_tensor_update_skip_reason(items) + if skip_reason is not None: + return 0, len(items) + + def _load_range(weight_items): + if not weight_items: + return 0, 0 + weight_dict = dict(weight_items) + try: + self.backend.model.load_weights(weight_dict) + return len(weight_items), 0 + except Exception as e: + if len(weight_items) == 1: + name, tensor = weight_items[0] + self.logger.warning( + "skip incompatible tensor update %s shape=%s dtype=%s: %s", + name, + tuple(tensor.shape) if hasattr(tensor, "shape") else None, + getattr(tensor, "dtype", None), + e, + ) + return 0, 1 + + split_idx = len(weight_items) // 2 + left_loaded, left_skipped = _load_range(weight_items[:split_idx]) + right_loaded, right_skipped = _load_range(weight_items[split_idx:]) + return left_loaded + right_loaded, left_skipped + right_skipped + + return _load_range(items) + + def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): + try: + monkey_patch_torch_reductions() + device_module = torch.get_device_module("cuda") + infered_device = device_module.current_device() + + if request.load_format == "flattened_bucket": + with cuda_rebuild_device_fallback(infered_device): + serialized_named_tensors = MultiprocessingSerializer.deserialize( + request.serialized_named_tensors[self.backend.rank_in_dp] + ) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors) + + def _unwrap_tensor(tensor, tp_rank, device): + if isinstance(tensor, LocalSerializedTensor): + tensor = tensor.get(tp_rank) 
+ clone = tensor.to(device).clone() + del tensor + return clone + + with cuda_rebuild_device_fallback(infered_device): + named_tensors = MultiprocessingSerializer.deserialize( + request.serialized_named_tensors[self.backend.rank_in_dp] + ) + named_tensors = { + name: _unwrap_tensor(tensor, tp_rank=self.backend.rank_in_dp, device=infered_device) + for name, tensor in self._iter_named_tensors(named_tensors) + } + + loaded, skipped = self._load_compatible_named_tensors(named_tensors) + + return True, f"Succeeded to update parameter online from tensor. loaded={loaded}, skipped={skipped}." + + except Exception as e: + message = f"Failed to update parameter online from tensor. Reason: {e}." + self.logger.error(message) + + return False, message + + def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq): + try: + from lightllm.utils.rl.bucketed_weight_transfer import BucketedWeightReceiver, get_zmq_handle + + zmq_handle = request.ipc_handle + if isinstance(zmq_handle, dict): + zmq_handle = zmq_handle.get(self.backend.rank_in_node, zmq_handle.get(str(self.backend.rank_in_node))) + if zmq_handle is None: + raise ValueError(f"Missing ipc_handle for rank_in_node={self.backend.rank_in_node}") + if zmq_handle in (None, "", "auto"): + zmq_handle = get_zmq_handle() + use_shm = request.use_shm + recv_device = torch.device("cuda", self.backend.current_device_id) + self.logger.debug( + "[LightLLM] RlBackendOps.update_weights_from_ipc: request.ipc_handle=%r, " + "resolved zmq_handle=%r, cuda_device_id=%s", + request.ipc_handle, + zmq_handle, + self.backend.current_device_id, + ) + + bucketed_weight_receiver = BucketedWeightReceiver( + zmq_handle=zmq_handle, device=recv_device, use_shm=use_shm + ) + bucketed_weight_receiver.receive_weights(on_bucket_received=self.backend.model.load_weights) + return True, "Succeeded to update parameter online from ipc." 
def exposed_forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
    """Dispatch an RL control request to ``RlBackendOps`` and wrap the outcome.

    Args:
        req: Request carrying ``func_name`` and ``func_args``; it is copied
            into this process with ``obtain`` before use.

    Returns:
        A ``GeneralModelToHttpRpcRsp``. Ordinary failures (ops not
        initialized, unsupported function, an exception raised by the op) are
        reported as ``success=False`` instead of being raised to the caller.
    """
    func_name = None
    try:
        req = obtain(req)
        func_name = req.func_name
        if self.rl_backend_ops is None:
            raise ValueError("RL backend ops is not initialized")
        if not RlBackendOps.supports(func_name):
            raise ValueError(
                f"Unsupported RL backend function {func_name}. "
                f"Supported functions: {sorted(RlBackendOps.SUPPORTED)}"
            )
        success, ret = self.rl_backend_ops.dispatch(func_name, req.func_args)
        return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=func_name, func_rsp=ret)
    except Exception as e:
        # Only ``Exception`` is converted into a response: SystemExit and
        # KeyboardInterrupt must keep propagating so the worker can shut down.
        logger.exception(f"forward to model backend failed: {str(e)}")
        if func_name is None:
            # ``obtain`` itself failed; reading the attribute from the original
            # request object may fail as well, so never let it mask the error.
            try:
                func_name = req.func_name
            except Exception:
                func_name = None
        return GeneralModelToHttpRpcRsp(
            success=False, msg=f"forward to model backend failed: {str(e)}", func_name=func_name
        )
def should_release_aborted_req_in_queue(self, req: Req):
    """Tell whether an aborted waiting request may be released by this queue."""
    # With multi-node TP, the abort state of a waiting request must first be
    # aligned through a broadcast from rank 0; it must not be released early in
    # each node's local scheduling queue based on that node's own shm state.
    return req.is_aborted and not self.router.is_multinode_tp

def mark_aborted_req_finished(self, req: Req):
    """Write the finish markers for a request aborted before inference began."""
    # A request that never started inference has produced no tokens. An EOS
    # position and the aborted status are written here so that the httpserver
    # recycle loop can finish the request normally and return an empty string.
    input_len = req.input_len
    # The shm arrays are linked before the slot at ``input_len`` is written.
    req.link_prompt_ids_shm_array()
    req.link_logprobs_shm_array()
    req.candetoken_out_len = 1
    req.finish_token_index = input_len
    req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0]
    req.shm_logprobs.arr[input_len] = 0
    req.finish_status.set_status(FinishStatus.FINISHED_ABORTED)

def release_aborted_req(self, req: Req):
    """Release an aborted waiting request.

    Frees its CPU cache pages, marks it finished-aborted and puts the request
    object back through the router's shm request manager.
    """
    logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
    self.free_aborted_req_cpu_cache_pages(req)
    self.mark_aborted_req_finished(req)
    self.router.shm_req_manager.put_back_req_obj(req)
    return
lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) class ChunkedPrefillQueue(BaseQueue): @@ -73,21 +70,20 @@ def generate_new_batch(self, current_batch: Batch): self._init_cache_list(current_batch, is_busy) can_run_list = [] abort_req_list = [] - aborted_count = 0 + consumed_req_count = 0 waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. - # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 - aborted_count += 1 + if self.should_release_aborted_req_in_queue(req): + consumed_req_count += 1 abort_req_list.append(req) continue ok_insert, new_batch_first_router_need_tokens = self._can_add_new_req( req, is_busy, new_batch_first_router_need_tokens ) if ok_insert: + consumed_req_count += 1 can_run_list.append(req) else: break @@ -95,11 +91,8 @@ def generate_new_batch(self, current_batch: Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: - req: Req = req - logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") - self.free_aborted_req_cpu_cache_pages(req) - self.router.shm_req_manager.put_back_req_obj(req) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + self.release_aborted_req(req) + self.waiting_req_list = self.waiting_req_list[consumed_req_count:] return new_batch def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py index 5ec09f5760..11e4198fb2 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py @@ -3,9 +3,6 @@ from typing import Tuple from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue 
-from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) class PDQueue(BaseQueue): @@ -69,21 +66,20 @@ def generate_new_batch(self, current_batch: Batch): can_run_list = [] abort_req_list = [] - aborted_count = 0 + consumed_req_count = 0 waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. - # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 - aborted_count += 1 + if self.should_release_aborted_req_in_queue(req): + consumed_req_count += 1 abort_req_list.append(req) continue ok_insert, estimated_peak_token_num, batch_req_num = self._can_add_new_req( req=req, estimated_peak_token_num=estimated_peak_token_num, batch_req_num=batch_req_num ) if ok_insert: + consumed_req_count += 1 can_run_list.append(req) else: break @@ -91,11 +87,8 @@ def generate_new_batch(self, current_batch: Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: - req: Req = req - logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") - self.free_aborted_req_cpu_cache_pages(req) - self.router.shm_req_manager.put_back_req_obj(req) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + self.release_aborted_req(req) + self.waiting_req_list = self.waiting_req_list[consumed_req_count:] return new_batch def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index 866e1b9f42..af8f875d4e 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -26,6 +26,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.reqs_waiting_for_dp_index: List[List[Req]] = [] return + def release_aborted_req(self, req: Req): + 
class RouterRlOpQueue:
    """Hand-off queue pairing each RL request with a future for its response.

    ``submit`` enqueues a request and blocks on the paired future; a consumer
    drains the queue with ``pop_all`` and completes each future.
    """

    def __init__(self):
        self._queue: "queue.Queue[Tuple[GeneralHttpToModelRpcReq, concurrent.futures.Future]]" = queue.Queue()

    def submit(self, req: GeneralHttpToModelRpcReq, timeout: float = 300.0) -> GeneralModelToHttpRpcRsp:
        """Enqueue ``req`` and wait up to ``timeout`` seconds for its result.

        Returns the value set on the future, or a failed
        ``GeneralModelToHttpRpcRsp`` when the wait times out.
        """
        pending: concurrent.futures.Future = concurrent.futures.Future()
        self._queue.put((req, pending))
        try:
            outcome = pending.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            return GeneralModelToHttpRpcRsp(
                success=False,
                msg=f"rl op {req.func_name} timeout after {timeout}s",
                func_name=req.func_name,
            )
        return outcome

    def pop_all(self) -> List[Tuple[GeneralHttpToModelRpcReq, concurrent.futures.Future]]:
        """Remove and return every queued ``(request, future)`` pair without blocking."""
        drained = []
        try:
            while True:
                drained.append(self._queue.get_nowait())
        except queue.Empty:
            pass
        return drained
args.rl_rpyc_port is None: + return None, None + + from rpyc.utils.server import ThreadedServer + import lightllm.utils.rpyc_fix_utils as _ + + server = ThreadedServer( + RouterRlRpcService(rl_op_queue), + hostname="127.0.0.1", + port=args.rl_rpyc_port, + protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + thread = threading.Thread(target=server.start, name="rl_rpyc_server", daemon=True) + thread.start() + logger.info(f"router rl rpyc server started on port {args.rl_rpyc_port}") + return server, thread diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index 3e74793634..b08e2d13e2 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -11,6 +11,7 @@ from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name +from lightllm.utils.process_check import start_parent_check_thread from .model_rpc_client import VisualModelRpcClient from .model_rpc import VisualModelRpcServer from ..objs import rpyc_config @@ -20,6 +21,7 @@ def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer") + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 5b9705ed0e..87aae86e4f 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -80,12 +80,15 @@ def init_vision_distributed_env(kvargs): device_id = kvargs["device_id"] set_current_device_id(device_id) torch.cuda.set_device(device_id) + # 不要在init_process_group时,显示的传入device_id + # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank + # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行 + # 
# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
    device_id=None,
):
    """Create a process group through torch's private process-group helper.

    Either ``store`` or ``init_method`` may be given, not both. When neither is
    given, ``init_method`` defaults to ``"env://"``. ``pg_options`` is passed to
    torch under whichever keyword name the installed torch version expects.

    Returns:
        The process group returned by ``_new_process_group_helper``.
    """
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # The version is compared numerically: as plain strings "2.10" sorts before
    # "2.6", which would select the removed keyword on torch >= 2.10.
    version_text = str(torch.__version__)
    version_parts = version_text.split("+", 1)[0].split(".")
    try:
        uses_backend_options = (int(version_parts[0]), int(version_parts[1])) >= (2, 6)
    except (IndexError, ValueError):
        # Unparseable version string: keep the previous string-based decision.
        uses_backend_options = version_text >= "2.6"
    pg_options_param_name = "backend_options" if uses_backend_options else "pg_options"
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
        device_id=device_id,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
+ range_start = base_port + instance_id * PORTS_PER_INSTANCE + range_end = range_start + PORTS_PER_INSTANCE + used_ports = used_ports or [] -def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=10000): port_list = [] - for port in range(from_port_num, 65536): + for port in range(range_start, range_end): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: result = s.connect_ex(("localhost", port)) if result != 0 and port not in used_ports: diff --git a/lightllm/utils/rl/__init__.py b/lightllm/utils/rl/__init__.py new file mode 100644 index 0000000000..12d5ab89e5 --- /dev/null +++ b/lightllm/utils/rl/__init__.py @@ -0,0 +1,6 @@ +""" +Utilities used by RL weight-update and colocated rollout integration paths. + +This package groups the CUDA IPC serializer, tensor bucketing helpers, and +bucketed weight-transfer protocol used by the RL endpoints. +""" diff --git a/lightllm/utils/rl/bucketed_weight_transfer.py b/lightllm/utils/rl/bucketed_weight_transfer.py new file mode 100644 index 0000000000..4849497f1a --- /dev/null +++ b/lightllm/utils/rl/bucketed_weight_transfer.py @@ -0,0 +1,302 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Bucketed RL weight transfer via ZMQ plus CUDA IPC or shared memory fallback. + +This module builds on torch_cuda_ipc.py for CUDA IPC device handling. 
def create_shared_memory(size: int, name: str):
    """Create shared memory for weight transfer. If already exists, attach to it.

    Args:
        size: Minimum number of bytes the segment must provide.
        name: Name of the shared-memory segment.

    Returns:
        The created or attached ``SharedMemory`` object.

    Raises:
        AssertionError: An existing segment called ``name`` is smaller than
            ``size``. The attached handle is closed before raising.
    """
    try:
        return shared_memory.SharedMemory(name=name, create=True, size=size)
    except FileExistsError:
        shm = shared_memory.SharedMemory(name=name)
    if shm.size < size:
        actual_size = shm.size
        # Do not leak the mapping of a segment we are refusing to use.
        shm.close()
        # Raised explicitly (same exception type as before) so the check also
        # runs when Python is started with -O, which strips assert statements.
        raise AssertionError(f"Stale shm segment '{name}': expected {size} bytes, got {actual_size}")
    return shm
async def async_send_weights(self, weights):
    """
    Send weights to the receiver. Accepts a sync generator or async iterator.

    Weights are packed back to back into the communication buffer. When the
    next weight would not fit, the filled bucket's metadata is sent and the
    receiver's acknowledgement awaited before the buffer is reused. The final
    bucket is flagged with ``is_last=True``. The socket and buffer are always
    cleaned up, also on failure.

    Args:
        weights: Generator or async iterator yielding (name, tensor) pairs

    Raises:
        AssertionError: A single weight is larger than the whole bucket.
    """
    try:
        from verl.workers.rollout.utils import ensure_async_iterator
    except ImportError:
        # verl is not a dependency of this package; fall back to a local
        # adapter so the sender also works where verl is not installed.
        async def ensure_async_iterator(iterable):
            if hasattr(iterable, "__aiter__"):
                async for item in iterable:
                    yield item
            else:
                for item in iterable:
                    yield item

    try:
        self._init_socket()
        self._init_buffer()

        # send bucket weights
        offset = 0
        bucket_meta: dict[str, TensorMetadata] = {}
        async for name, weight in ensure_async_iterator(weights):
            # model parameters are in fp32 full precision
            # (vermouth1992) we should not force cast weight here because some parameters
            # (such as moe gate) have to keep fp32 precision. If a weight is bf16 in the rollout side,
            # the rollout should automatically cast on demand. However, this would incur a higher weight
            # transfer volume.

            # fill the tensor bucket
            if offset + weight.nbytes > self.bucket_size:
                # Pending async copies must land before the receiver reads the buffer.
                torch.cuda.synchronize()
                self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False})
                self.socket.recv()
                bucket_meta = {}
                offset = 0

            # TODO: slice embedding layer weight into chunks
            assert offset + weight.nbytes <= self.bucket_size, (
                f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
                f"Please increase rollout.update_weights_bucket_megabytes({self.bucket_size_mb} MB)."
            )
            bucket_meta[name] = {
                "name": name,
                "shape": weight.shape,
                "dtype": weight.dtype,
                "offset": offset,
            }
            # ``view(-1)`` requires contiguous memory; ``contiguous()`` returns the
            # tensor itself when it already is, so the common path is unchanged.
            flat_bytes = weight.contiguous().view(-1).view(torch.uint8)
            self.buffer[offset : offset + weight.nbytes].copy_(flat_bytes, non_blocking=True)
            offset += weight.nbytes

        # send the last bucket
        torch.cuda.synchronize()
        self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True})
        self.socket.recv()
    finally:
        self._cleanup()
def receive_weights(self, on_bucket_received: callable):
    """
    Receive weights from sender and process each bucket via callback.

    Loops until a bucket whose metadata has ``is_last`` set has been handled.
    The socket and buffer are released in ``_cleanup`` even when the callback
    or the transfer raises.

    Args:
        on_bucket_received: Callback function(weight_dict: dict[str, torch.Tensor]) called per bucket.
    """
    try:
        self._init_socket()
        self._init_buffer()

        # receive bucket and update weights
        while True:
            metadata = self.socket.recv_pyobj()
            weights, tensor = [], None
            for name, meta in metadata["bucket_meta"].items():
                shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"]
                # Byte length of this tensor inside the uint8 communication buffer.
                size = dtype.itemsize * shape.numel()
                # NOTE: we need to clone the tensor to release CUDA IPC memory
                # but for shared memory, it's not necessary and if we do clone,
                # it will cause extra memory copy overhead and slow down the process.
                tensor = self.buffer[offset : offset + size].view(dtype=dtype).view(shape)
                if not self.use_shm:
                    tensor = tensor.clone()
                else:
                    tensor = tensor.to(self.device)
                weights.append((name, tensor))
            # The tensors have been copied out of the buffer (clone / .to) above;
            # synchronize, then send the empty reply before running the callback.
            torch.cuda.synchronize()
            self.socket.send(b"")
            on_bucket_received(dict(weights))
            del weights, tensor
            if metadata["is_last"]:
                break
    finally:
        self._cleanup()
This module wraps ForkingPickler for that handoff and uses a guarded +unpickler because the payload may come from a trainer process. + +Copied from: +https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py +""" +import base64 +import pickle +import io +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import List + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """ + Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. 
+ """ + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data, validate=True) + + return SafeUnpickler(io.BytesIO(data)).load() + + +class SafeUnpickler(pickle.Unpickler): + ALLOWED_MODULE_PREFIXES = { + # --- Python types --- + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # --- PyTorch types --- + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # --- torch distributed --- + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # --- multiprocessing --- + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # --- PEFT / LoRA --- + "peft.", + "transformers.", + "huggingface_hub.", + # --- SGLang & Unitest --- + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + # --- LightLLM --- + "lightllm.utils.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module, name): + # Block deterministic attacks + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + # Allowlist of safe-to-load modules. + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): + return super().find_class(module, name) + + # Block everything else. 
@dataclass
class FlattenedTensorMetadata:
    """Metadata for a tensor in a flattened bucket"""

    # Name the tensor was registered under.
    name: str
    # Original shape and dtype, used to rebuild the tensor from the flat buffer.
    shape: torch.Size
    dtype: torch.dtype
    # Half-open range [start_idx, end_idx) of this tensor inside the flattened buffer.
    start_idx: int
    end_idx: int
    # Length of that range. Buckets built from named tensors flatten to uint8,
    # so for those this is a byte count rather than an element count.
    numel: int


class FlattenedTensorBucket:
    """
    A bucket that flattens multiple tensors into a single tensor for efficient processing
    while preserving all metadata needed for reconstruction.
    """

    # This field exists solely so callers can check whether the class supports this feature.
    supports_multi_dtypes = True

    def __init__(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]] = None,
        flattened_tensor: torch.Tensor = None,
        metadata: List[FlattenedTensorMetadata] = None,
    ):
        """
        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.

        Args:
            named_tensors: List of (name, tensor) tuples (for creating new bucket)
            flattened_tensor: Pre-flattened tensor (for reconstruction)
            metadata: Pre-computed metadata (for reconstruction)

        Raises:
            ValueError: If ``named_tensors`` is an empty list, or if it is None and
                ``flattened_tensor`` / ``metadata`` are not both provided.
        """
        if named_tensors is not None:
            if not named_tensors:
                raise ValueError("Cannot create empty tensor bucket")

            self.metadata: List[FlattenedTensorMetadata] = []
            byte_views: List[torch.Tensor] = []
            current_idx = 0

            for name, tensor in named_tensors:
                # Reinterpret every tensor as raw bytes so different dtypes can share one buffer.
                flattened = tensor.flatten().view(torch.uint8)
                byte_views.append(flattened)

                numel = flattened.numel()
                self.metadata.append(
                    FlattenedTensorMetadata(
                        name=name,
                        shape=tensor.shape,
                        dtype=tensor.dtype,
                        start_idx=current_idx,
                        end_idx=current_idx + numel,
                        numel=numel,
                    )
                )
                current_idx += numel

            # Concatenate all flattened tensors
            self.flattened_tensor: torch.Tensor = torch.cat(byte_views, dim=0)
        else:
            # Initialize from pre-flattened data
            if flattened_tensor is None or metadata is None:
                raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata")
            self.flattened_tensor = flattened_tensor
            self.metadata = metadata

    def get_flattened_tensor(self) -> torch.Tensor:
        """Get the flattened tensor containing all bucket tensors"""
        return self.flattened_tensor

    def get_metadata(self) -> List[FlattenedTensorMetadata]:
        """Get metadata for all tensors in the bucket"""
        return self.metadata

    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
        """
        Reconstruct the original tensors from the flattened tensor.

        Each result is a view into the flattened buffer whenever the slice can be
        reinterpreted in place. ``Tensor.view(dtype)`` to a wider dtype requires the
        slice's storage offset to be a multiple of the element-size ratio; in a
        mixed-dtype bucket (e.g. 3 uint8 bytes followed by float32 data) that does
        not hold, so such slices are copied to a fresh, aligned buffer first.

        Returns:
            List of (name, tensor) tuples in metadata order.
        """
        src_itemsize = self.flattened_tensor.element_size()
        itemsize_by_dtype = {}
        reconstructed = []

        for meta in self.metadata:
            chunk = self.flattened_tensor[meta.start_idx : meta.end_idx]

            itemsize = itemsize_by_dtype.get(meta.dtype)
            if itemsize is None:
                itemsize = torch.empty((), dtype=meta.dtype).element_size()
                itemsize_by_dtype[meta.dtype] = itemsize

            if itemsize > src_itemsize and chunk.storage_offset() % (itemsize // src_itemsize) != 0:
                # Misaligned for the target dtype: clone() yields a contiguous copy at offset 0.
                chunk = chunk.clone()

            reconstructed.append((meta.name, chunk.view(meta.dtype).reshape(meta.shape)))

        return reconstructed
def monkey_patch_torch_reductions():
    """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed

    Replaces ``reductions.reduce_tensor`` and ``reductions.rebuild_cuda_tensor`` with
    wrappers that carry a GPU UUID instead of a process-local device index, then
    re-runs ``reductions.init_reductions()``. Calling it again is a no-op.
    """

    # Currently, NPU does not support UUID. This has been temporarily commented out,
    # with support expected in the fourth quarter.
    # if _is_npu:
    #     return

    # The saved original doubles as the "already patched" marker.
    if hasattr(reductions, "_reduce_tensor_original"):
        return

    reductions._reduce_tensor_original = reductions.reduce_tensor
    reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor

    reductions.reduce_tensor = _reduce_tensor_modified
    reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified

    reductions.init_reductions()


# The torch CUDA IPC rebuild signature has kept the device argument at this
# index for years. Keep this constant in one place because both the global
# monkey patch and local bucketed IPC rebuild path need to rewrite it.
CUDA_IPC_REBUILD_DEVICE_ARG_INDEX = 6
_rebuild_device_fallback: ContextVar[Union[int, None]] = ContextVar("rebuild_device_fallback", default=None)


@contextmanager
def cuda_rebuild_device_fallback(device: Union[int, None]):
    """Within the ``with`` body, use ``device`` when a received GPU UUID matches no local device."""
    token = _rebuild_device_fallback.set(device)
    try:
        yield
    finally:
        _rebuild_device_fallback.reset(token)


def _reduce_tensor_modified(*args, **kwargs):
    """Call the original reducer and, for CUDA rebuilds only, swap the device index for its UUID.

    The original reducer returns differently shaped argument tuples for other
    tensor kinds (a CPU tensor yields a 3-argument ``rebuild_tensor`` call), so
    rewriting index 6 unconditionally would raise ``IndexError`` or corrupt an
    unrelated argument. Those reductions are passed through untouched.
    """
    output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
    cuda_rebuilders = (_rebuild_cuda_tensor_modified, getattr(reductions, "_rebuild_cuda_tensor_original", None))
    if any(output_fn is fn for fn in cuda_rebuilders):
        output_args = _modify_tuple(output_args, CUDA_IPC_REBUILD_DEVICE_ARG_INDEX, cuda_device_to_uuid)
    return output_fn, output_args


def _rebuild_cuda_tensor_modified(*args):
    """Map the (possibly UUID) device argument back to a local index, then rebuild as usual."""
    args = _modify_tuple(args, CUDA_IPC_REBUILD_DEVICE_ARG_INDEX, cuda_device_from_maybe_uuid)
    return reductions._rebuild_cuda_tensor_original(*args)


def get_current_device_name() -> str:
    """Return ``"cuda"`` when CUDA is available, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def get_current_device_module():
    """Return the ``torch`` submodule named after the current device, defaulting to ``torch.cuda``."""
    device_name = get_current_device_name()
    try:
        return getattr(torch, device_name)
    except AttributeError:
        return torch.cuda


def get_current_device_id() -> int:
    """Return ``current_device()`` of the current device module."""
    return get_current_device_module().current_device()


def cuda_device_to_uuid(device: int) -> str:
    """Return the UUID string of local CUDA device ``device``."""
    return str(torch.cuda.get_device_properties(device).uuid)


def cuda_device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
    """Translate a device argument received over IPC into a local device index.

    An ``int`` is returned unchanged. A ``str`` is matched against the UUIDs of
    the visible CUDA devices; if none matches, the value installed by
    ``cuda_rebuild_device_fallback`` is used when one is set.

    Raises:
        Exception: If a UUID matches nothing and no fallback is set, or the
            argument is neither ``int`` nor ``str``.
    """
    if isinstance(device_maybe_uuid, int):
        return device_maybe_uuid

    if isinstance(device_maybe_uuid, str):
        for device in range(torch.cuda.device_count()):
            if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
                return device
        fallback_device = _rebuild_device_fallback.get()
        if fallback_device is not None:
            return fallback_device
        raise Exception("Invalid device_uuid=" + device_maybe_uuid)

    raise Exception(f"Unknown type: {device_maybe_uuid=}")


def rebuild_cuda_ipc_tensor(handle: tuple[Callable, tuple], device_id: Union[int, None] = None) -> torch.Tensor:
    """Invoke a ``(rebuild_fn, args)`` handle, optionally overriding its device argument.

    Args:
        handle: The rebuild callable together with its positional arguments.
        device_id: When not None, replaces the argument at
            ``CUDA_IPC_REBUILD_DEVICE_ARG_INDEX`` before the call.
    """
    func, args = handle
    list_args = list(args)
    if device_id is not None:
        list_args[CUDA_IPC_REBUILD_DEVICE_ARG_INDEX] = device_id
    return func(*list_args)


def _modify_tuple(t, index: int, modifier: Callable):
    """Return a copy of tuple ``t`` with ``t[index]`` replaced by ``modifier(t[index])``."""
    return *t[:index], modifier(t[index]), *t[index + 1 :]
class MemoryTag(Enum):
    """Tags used to group allocations handed to torch_memory_saver."""

    KV_CACHE = "kv_cache"
    WEIGHT = "weights"
    GRAPH = "graph"

    def is_kv_cache(self):
        return self is MemoryTag.KV_CACHE

    def is_weight(self):
        return self is MemoryTag.WEIGHT

    def is_graph(self):
        return self is MemoryTag.GRAPH

    def __str__(self):
        # Render as the bare tag value rather than "MemoryTag.X".
        return self.value


class TorchMemorySaverWrapper:
    """Factory: instantiating it returns the real saver wrapper or a no-op stand-in."""

    def __new__(cls, enable_torch_memory_saver: bool = False):
        if not enable_torch_memory_saver:
            return _TorchMemorySaverFake()
        assert (
            HAS_TORCH_MEMORY_SAVER
        ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`."
        return _TorchMemorySaver()


class _TorchMemorySaver:
    """Thin pass-through to the ``torch_memory_saver`` package, keyed by ``MemoryTag``."""

    @contextmanager
    def configure_subprocess(self):
        with tms_configure_subprocess():
            yield

    def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
        tag_value = tag.value
        return torch_memory_saver.region(tag=tag_value, enable_cpu_backup=enable_cpu_backup)

    def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
        # Graph captures are always tagged GRAPH.
        return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value)

    def disable(self):
        return torch_memory_saver.disable()

    def pause(self, tag: MemoryTag):
        tag_value = tag.value
        return torch_memory_saver.pause(tag=tag_value)

    def resume(self, tag: MemoryTag):
        tag_value = tag.value
        return torch_memory_saver.resume(tag=tag_value)


class _TorchMemorySaverFake:
    """Stand-in used when torch_memory_saver is disabled: context managers do nothing."""

    @contextmanager
    def configure_subprocess(self):
        yield

    @contextmanager
    def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
        yield

    def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
        # Without the saver, fall back to the plain torch graph-capture context.
        return torch.cuda.graph(graph_obj, **kwargs)

    @contextmanager
    def disable(self):
        yield

    def pause(self, tag: MemoryTag):
        logger.warning("torch_memory_saver is not enabled, pause is not supported.")

    def resume(self, tag: MemoryTag):
        logger.warning("torch_memory_saver is not enabled, resume is not supported.")
PROMPTS = [
    "Write a long detailed essay about the history of computing.",
    "Tell me a long story about dragons and knights.",
    "Explain quantum mechanics in detail with lots of examples.",
    "Describe the plot of a 5-part fantasy novel series.",
    "Compose a long poem about the seasons.",
]


async def stream_task(client, url, mode, max_new_tokens):
    """Run one /generate_stream request and report how it ended.

    In ``"disconnect"`` mode the stream is abandoned once a randomly chosen
    number of response lines has been read.

    Returns:
        Tuple ``(mode, outcome, lines_read, elapsed_seconds)`` where ``outcome``
        is the reported finish reason, ``"ok"`` when none was seen, or
        ``"exc:<ExceptionName>"`` if the request raised.
    """
    request_body = {
        "inputs": random.choice(PROMPTS),
        "parameters": {"max_new_tokens": max_new_tokens, "temperature": 0.7, "do_sample": True},
    }
    # Drawn for every stream, whatever the mode.
    disconnect_at = random.randint(20, 200)
    lines_read = 0
    reason = None
    started = time.time()
    try:
        async with client.stream("POST", f"{url}/generate_stream", json=request_body, timeout=180.0) as resp:
            async for line in resp.aiter_lines():
                # Every response line counts, not only "data:" events.
                lines_read += 1
                if line.startswith("data:"):
                    event = json.loads(line[len("data:") :])
                    if event.get("finished"):
                        reason = event.get("finish_reason")
                if mode == "disconnect" and lines_read >= disconnect_at:
                    break
        return (mode, reason or "ok", lines_read, time.time() - started)
    except Exception as e:
        return (mode, f"exc:{type(e).__name__}", lines_read, time.time() - started)


def summarize(results):
    """Print how many streams ended with each (mode, outcome) pair or raised exception type."""

    def outcome_key(item):
        if isinstance(item, tuple):
            return (item[0], item[1])
        # gather(return_exceptions=True) hands back the exception object itself.
        return f"raised:{type(item).__name__}"

    tally = Counter(outcome_key(item) for item in results)
    for key, count in sorted(tally.items(), key=str):
        print(f" {key}: {count}")


async def stage_abort_all(client, url, concurrency, max_new_tokens):
    """Stage 1: start ``concurrency`` streams, post abort_all after 2 s, then wait for all of them."""
    print("\n===== STAGE 1: abort_all on N concurrent streams =====")
    pending = [asyncio.create_task(stream_task(client, url, "finish", max_new_tokens)) for _ in range(concurrency)]
    await asyncio.sleep(2.0)
    abort_started = time.time()
    abort_resp = await client.post(f"{url}/abort_request", json={"abort_all": True}, timeout=10.0)
    print(f"abort_all status={abort_resp.status_code}")
    outcomes = await asyncio.gather(*pending, return_exceptions=True)
    print(f"all streams settled in {time.time() - abort_started:.2f}s")
    summarize(outcomes)


async def stage_random_chaos(client, url, concurrency, max_new_tokens):
    """Stage 2: start ``concurrency`` streams, each randomly disconnecting (80%) or finishing (20%)."""
    print("\n===== STAGE 2: random per-stream chaos =====")
    fates = random.choices(["disconnect", "finish"], weights=[80, 20], k=concurrency)
    pending = [asyncio.create_task(stream_task(client, url, fate, max_new_tokens)) for fate in fates]
    outcomes = await asyncio.gather(*pending, return_exceptions=True)
    summarize(outcomes)


async def run(url, concurrency, max_new_tokens):
    """Run both stages in order on one shared HTTP client."""
    async with httpx.AsyncClient() as client:
        await stage_abort_all(client, url, concurrency, max_new_tokens)
        await stage_random_chaos(client, url, concurrency, max_new_tokens)
    print("\nALL CHAOS TESTS PASSED")


def main():
    """Parse the command line, seed the RNG, and run the two stages."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://127.0.0.1:8000")
    parser.add_argument("--concurrency", type=int, default=24)
    parser.add_argument("--max_new_tokens", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    asyncio.run(run(args.url, args.concurrency, args.max_new_tokens))


if __name__ == "__main__":
    main()
GREEN = "\033[32m"
RED = "\033[31m"
YELLOW = "\033[33m"
RESET = "\033[0m"


def banner(msg: str):
    """Print a yellow section header."""
    print(f"\n{YELLOW}=== {msg} ==={RESET}", flush=True)


def ok(msg: str):
    """Print a green OK line."""
    print(f" {GREEN}OK{RESET} {msg}", flush=True)


def fail(msg: str):
    """Print a red FAIL line."""
    print(f" {RED}FAIL{RESET} {msg}", flush=True)


# ---------------- HTTP helpers ----------------


def _get_health(url: str, timeout=5):
    """GET ``<url>/health`` and return the response."""
    return requests.get(url + "/health", timeout=timeout)


def post_abort(url: str, request_id: Optional[int] = None, abort_all: bool = False) -> Tuple[int, str]:
    """POST to ``/abort_request`` and return ``(status_code, body_text)``.

    ``request_id`` is only included in the payload when it is not None.
    """
    payload = {"abort_all": abort_all}
    if request_id is not None:
        payload["request_id"] = request_id
    r = requests.post(url + "/abort_request", json=payload, timeout=30)
    return r.status_code, r.text


# ---------------- streaming helpers ----------------


def _stream_run(
    url: str,
    prompt: str,
    max_new_tokens: int,
    x_request_id: str,
    out: dict,
    close_after_n: Optional[int] = None,
):
    """
    Issue a /generate_stream and append every event to out["events"].
    If close_after_n is set, the underlying socket is forcibly closed
    (TCP RST via SO_LINGER + close) after that many events arrive — kept
    here for completeness even though no current stage uses it. Sets
    out["error"] on transport errors.

    Also written to ``out``: "start"/"end" timestamps, "closed_intentionally",
    and "close_error" if the forced close itself failed. Each event gets a
    "_t" field holding seconds elapsed since "start".
    """
    headers = {"X-Request-Id": x_request_id, "Content-Type": "application/json"}
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "ignore_eos": True,
        },
    }
    out["events"] = []
    out["start"] = time.time()
    out["error"] = None
    out["closed_intentionally"] = False
    try:
        with requests.post(url + "/generate_stream", json=body, headers=headers, stream=True, timeout=120) as r:
            r.raise_for_status()
            for raw in r.iter_lines(decode_unicode=True):
                if not raw:
                    continue
                if raw.startswith("data:"):
                    raw = raw[len("data:") :]
                try:
                    ev = json.loads(raw)
                except Exception:
                    # Lines that are not JSON are skipped.
                    continue
                ev["_t"] = time.time() - out["start"]
                out["events"].append(ev)
                if close_after_n is not None and len(out["events"]) >= close_after_n:
                    out["closed_intentionally"] = True
                    # urllib3 keeps the socket pooled; we need direct access to force-close.
                    # Reach into urllib3 to force a TCP RST so the server sees
                    # the disconnect immediately rather than after a graceful
                    # FIN that hypercorn might not propagate while the response
                    # is mid-stream.
                    try:
                        import socket as _socket
                        import struct as _struct

                        sock = r.raw._fp.fp.raw._sock  # type: ignore[attr-defined]
                        # SO_LINGER with timeout 0 -> RST on close. The option value is a
                        # C `struct linger` of two ints, so pack it in native byte order
                        # instead of hand-encoding little-endian bytes.
                        l_onoff, l_linger = 1, 0
                        sock.setsockopt(
                            _socket.SOL_SOCKET,
                            _socket.SO_LINGER,
                            _struct.pack("ii", l_onoff, l_linger),
                        )
                        sock.close()
                    except Exception as e:
                        out["close_error"] = repr(e)
                    break
                if ev.get("finished"):
                    break
    except Exception as e:
        out["error"] = repr(e)
    out["end"] = time.time()


def start_stream(
    url: str, prompt: str, max_new_tokens: int, close_after_n: Optional[int] = None
) -> Tuple[threading.Thread, dict, str]:
    """Start ``_stream_run`` on a daemon thread.

    Returns:
        ``(thread, out, x_request_id)`` where ``out`` is the dict the thread
        fills in and ``x_request_id`` is the random id sent as ``X-Request-Id``.
    """
    xid = uuid.uuid4().hex
    out = {}
    th = threading.Thread(target=_stream_run, args=(url, prompt, max_new_tokens, xid, out, close_after_n))
    th.daemon = True
    th.start()
    return th, out, xid


def wait_for_first_token(out: dict, timeout: float = 30.0) -> bool:
    """Poll ``out["events"]`` every 50 ms; return True once it is non-empty, False on timeout."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if out.get("events"):
            return True
        time.sleep(0.05)
    return False


def get_finish_reason(out: dict) -> Optional[str]:
    """Return the last truthy ``finish_reason`` among the recorded events, or None."""
    for ev in reversed(out.get("events") or []):
        fr = ev.get("finish_reason")
        if fr:
            return fr
    return None
def _read_log_tail(server_log_path: Optional[str], max_bytes: int = 256 * 1024) -> str:
    """Return up to the last ``max_bytes`` bytes of the log as text, or "" if there is no log."""
    if not server_log_path or not os.path.exists(server_log_path):
        return ""
    try:
        total = os.path.getsize(server_log_path)
        with open(server_log_path, "rb") as fh:
            if total > max_bytes:
                fh.seek(total - max_bytes)
            data = fh.read()
            return data.decode("utf-8", errors="ignore")
    except FileNotFoundError:
        # The file vanished between the existence check and the read.
        return ""


def grep_log_for_pattern(server_log_path: Optional[str], pattern: re.Pattern, timeout: float = 5.0) -> Optional[str]:
    """Poll the tail of the server log for a regex match."""
    give_up_at = time.time() + timeout
    while time.time() < give_up_at:
        found = pattern.search(_read_log_tail(server_log_path))
        if found:
            return found.group(0)
        time.sleep(0.1)
    return None


def grep_log_after_offset(
    server_log_path: Optional[str], start_offset: int, pattern: re.Pattern, timeout: float = 5.0
) -> Optional[str]:
    """Poll the server log starting at start_offset for a regex match.
    Only content written after start_offset is considered, so this isolates
    a stage from log produced by earlier stages."""
    if not server_log_path:
        return None
    give_up_at = time.time() + timeout
    while time.time() < give_up_at:
        try:
            with open(server_log_path, "rb") as fh:
                fh.seek(start_offset)
                fresh = fh.read().decode("utf-8", errors="ignore")
        except FileNotFoundError:
            fresh = ""
        found = pattern.search(fresh)
        if found:
            return found.group(0)
        time.sleep(0.1)
    return None


def server_log_size(server_log_path: Optional[str]) -> int:
    """Return the current size of the log in bytes, or 0 when no path is given or it does not exist."""
    if server_log_path and os.path.exists(server_log_path):
        return os.path.getsize(server_log_path)
    return 0


def lookup_lightllm_req_id_from_log(server_log_path: str, x_request_id: str, timeout: float = 5.0) -> Optional[int]:
    """Poll the log tail for the line that pairs ``x_request_id`` with a lightllm_req_id; return that id."""
    pattern = re.compile(rf"received req X-Request-Id:{re.escape(x_request_id)}\b.*?lightllm_req_id:(\d+)")
    give_up_at = time.time() + timeout
    while time.time() < give_up_at:
        found = pattern.search(_read_log_tail(server_log_path))
        if found:
            return int(found.group(1))
        time.sleep(0.1)
    return None


# ---------------- stages ----------------


def stage_a_bogus_id(url: str) -> bool:
    """Stage A: aborting an id that does not exist must still answer HTTP 200."""
    banner("Stage A: abort with a non-existent id")
    bogus = 99_999_999
    code, text = post_abort(url, request_id=bogus, abort_all=False)
    print(f" /abort_request request_id={bogus} -> HTTP {code} body={text!r}")
    if code == 200:
        ok("HTTP 200")
        return True
    fail(f"expected HTTP 200, got {code}")
    return False


def stage_b_abort_all_idle(url: str) -> bool:
    """Stage B: abort_all with nothing running must answer HTTP 200."""
    banner("Stage B: abort_all on an idle server")
    code, text = post_abort(url, abort_all=True)
    print(f" /abort_request abort_all=true -> HTTP {code} body={text!r}")
    if code == 200:
        ok("HTTP 200")
        return True
    fail(f"expected HTTP 200, got {code}")
    return False


def stage_c_abort_running(url: str, server_log_path: Optional[str]) -> bool:
    """Stage C: abort_all while a stream is running; the stream must end and the log must record the abort."""
    banner("Stage C: abort_all on a running stream")
    log_offset = server_log_size(server_log_path)
    th, out, xid = start_stream(url, "Recite the alphabet repeatedly.", max_new_tokens=200)
    if not wait_for_first_token(out, timeout=30.0):
        fail("did not receive any tokens before abort")
        return False
    first_t = out["events"][0]["_t"]
    ok(f"first token at +{first_t:.2f}s")

    target_id = None
    if server_log_path:
        target_id = lookup_lightllm_req_id_from_log(server_log_path, xid, timeout=5.0)
    print(f" resolved lightllm_req_id from log: {target_id}")

    code, text = post_abort(url, abort_all=True)
    print(f" /abort_request abort_all=true -> HTTP {code} body={text!r}")
    if code != 200:
        fail(f"expected HTTP 200, got {code}")
        return False

    th.join(timeout=60.0)
    if th.is_alive():
        fail("stream did not terminate within 60s of abort")
        return False
    fr = get_finish_reason(out)
    n = len(out.get("events") or [])
    print(f" stream events received: {n}, finish_reason={fr!r}, error={out.get('error')!r}")

    # The api itself succeeded; whether the stream got a clean 'abort' finish reason
    # depends on which mode-backend the server is running. We DO assert the abort
    # warning landed in the server log though, scoped to log content produced after
    # this stage started so we don't match earlier-stage residue.
    if not server_log_path:
        print(" no --server_log_path; skipped log assertion")
    else:
        if target_id is None:
            pat = re.compile(r"aborted group_request_id \d+")
        else:
            pat = re.compile(rf"aborted group_request_id {target_id}\b")
        hit = grep_log_after_offset(server_log_path, log_offset, pat, timeout=5.0)
        if not hit:
            fail("could not find 'aborted group_request_id' in server log (post-stage)")
            return False
        ok(f"server log recorded: {hit!r}")
    ok("stream terminated and abort acknowledged")
    return True
def stage_d_abort_by_id(url: str, server_log_path: Optional[str]) -> bool:
    """Stage D: abort one running stream by its lightllm request id.

    The id is resolved from the server log through the stream's X-Request-Id,
    so the stage is skipped (and counted as passed) when no log path is given.
    """
    banner("Stage D: abort by specific request_id on a running stream")
    if not server_log_path:
        print(" --server_log_path not provided; skipping (we need the log to resolve req_id)")
        return True

    # Remember where the log ended so only this stage's output is searched later.
    log_offset = server_log_size(server_log_path)
    th, out, xid = start_stream(url, "Sing a long lullaby for the moon.", max_new_tokens=300)
    if not wait_for_first_token(out, timeout=30.0):
        fail("did not receive any tokens before abort")
        return False
    ok(f"first token at +{out['events'][0]['_t']:.2f}s, X-Request-Id={xid[:8]}…")

    target_id = lookup_lightllm_req_id_from_log(server_log_path, xid, timeout=5.0)
    if target_id is None:
        fail("could not resolve lightllm_req_id from server log; cannot test by-id abort")
        return False
    print(f" resolved lightllm_req_id: {target_id}")

    code, text = post_abort(url, request_id=target_id, abort_all=False)
    print(f" /abort_request request_id={target_id} -> HTTP {code} body={text!r}")
    if code != 200:
        fail(f"expected HTTP 200, got {code}")
        return False

    th.join(timeout=60.0)
    if th.is_alive():
        fail("stream did not terminate within 60s")
        return False
    fr = get_finish_reason(out)
    n = len(out.get("events") or [])
    print(f" stream events received: {n}, finish_reason={fr!r}")

    abort_line = re.compile(rf"aborted group_request_id {target_id}\b")
    hit = grep_log_after_offset(server_log_path, log_offset, abort_line, timeout=5.0)
    if hit:
        ok(f"server log recorded: {hit!r}")
        return True
    fail(f"could not find 'aborted group_request_id {target_id}' in server log (post-stage)")
    return False


def stage_e_health_after(url: str) -> bool:
    """Stage E: after the abort stages, a plain /generate must return non-empty text."""
    banner("Stage E: server still serves a normal /generate")
    request_body = {
        "inputs": "The capital of France is",
        "parameters": {"max_new_tokens": 6, "do_sample": False},
    }
    r = requests.post(url + "/generate", json=request_body, timeout=60)
    print(f" /generate -> HTTP {r.status_code} {r.text[:200]}")
    if r.status_code != 200:
        fail(f"final /generate failed with {r.status_code}")
        return False
    text = r.json().get("generated_text")
    # The field may arrive as a list; judge its first entry.
    if isinstance(text, list):
        text = text[0]
    if text and text.strip():
        ok(f"final /generate returned {text!r}")
        return True
    fail("final /generate returned empty text")
    return False


# ---------------- main ----------------


def main():
    """Check the server is reachable, run stages A-E in order, and exit 1 if any stage failed."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--url", default="http://127.0.0.1:8000")
    ap.add_argument(
        "--server_log_path",
        default=None,
        help="optional path to the server stdout/stderr log; enables log-grep assertions",
    )
    args = ap.parse_args()

    try:
        r = _get_health(args.url)
        r.raise_for_status()
    except Exception as e:
        fail(f"server at {args.url} not reachable: {e}")
        sys.exit(1)
    ok(f"server reachable at {args.url}")

    # Every stage runs, even after an earlier failure, so the summary is complete.
    stages = [
        ("A", lambda: stage_a_bogus_id(args.url)),
        ("B", lambda: stage_b_abort_all_idle(args.url)),
        ("C", lambda: stage_c_abort_running(args.url, args.server_log_path)),
        ("D", lambda: stage_d_abort_by_id(args.url, args.server_log_path)),
        ("E", lambda: stage_e_health_after(args.url)),
    ]
    results = [(name, run_stage()) for name, run_stage in stages]

    print("\n" + "=" * 50)
    all_ok = True
    for name, passed in results:
        tag = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}"
        print(f" Stage {name}: {tag}")
        all_ok = all_ok and passed
    if not all_ok:
        sys.exit(1)
    print(f"\n{GREEN}ALL ABORT STAGES PASSED{RESET}")


if __name__ == "__main__":
    main()
def test_routing_export(url: str = "http://localhost:8000"):
    """Request a generation and check that the response carries usable routed-expert data.

    The ``routed_experts`` field is read as a dict with ``shape``, ``dtype`` and
    base64 ``data`` entries and is decoded into a NumPy array indexed as
    ``[token, moe_layer, top_k]``.

    Returns:
        bool: True when routing data is present and not all zeros; False on
        connection errors, timeouts, non-200 responses, missing routing data,
        or all-zero routing data.
    """
    print(f"Testing routing export at {url}")
    print("-" * 50)

    try:
        response = requests.post(
            f"{url}/generate",
            json={
                "inputs": "What is the capital of France? What is the capital of France?",
                "parameters": {
                    "max_new_tokens": 50,
                    # "return_routed_experts": True,
                    # "repetition_penalty": 1.0,
                },
            },
            timeout=60,
        )
    except requests.exceptions.ConnectionError:
        print(f"ERROR: Cannot connect to server at {url}")
        print("Make sure the LightLLM server is running with --enable_return_routed_experts")
        return False
    except requests.exceptions.Timeout:
        print("ERROR: Request timed out")
        return False

    print(f"Status: {response.status_code}")

    if response.status_code != 200:
        print(f"ERROR: Request failed with status {response.status_code}")
        print(f"Response: {response.text}")
        return False

    res = response.json()
    print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...")

    if "routed_experts" not in res or not res["routed_experts"]:
        print("\nWARNING: No routed_experts in response.")
        print("This could mean:")
        print(" - The model is not a MoE model")
        print(" - The server was not started with --enable_return_routed_experts")
        print(" - The routing capture manager was not initialized")
        return False

    routing_info = res["routed_experts"]
    shape = routing_info["shape"]
    dtype_str = routing_info["dtype"]
    dtype = np.dtype(dtype_str)
    data = base64.b64decode(routing_info["data"])
    routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)

    print(f"\n{'=' * 50}")
    print("ROUTING CAPTURE SUCCESS!")
    print(f"{'=' * 50}")
    print(f"Shape: {shape}")
    print(f"Dtype: {dtype}")
    print(f"Num tokens: {shape[0]}")
    print(f"Num MoE layers: {shape[1]}")
    print(f"Top-K: {shape[2]}")

    # Compute payload size savings
    int32_size = np.prod(shape) * 4
    actual_size = len(data)
    savings = (1 - actual_size / int32_size) * 100
    print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)")

    print("\nSample routing (first layer, first 5 tokens):")
    # Cap the sample at five tokens, as the header promises.
    num_tokens_to_show = min(5, shape[0])
    for i in range(num_tokens_to_show):
        print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}")

    if np.all(routing_array == 0):
        print("\nWARNING: All routing data is zeros. Capture may not be working correctly.")
        return False

    print("\nTest PASSED!")
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test R3 routing export feature")
    parser.add_argument("--url", default="http://localhost:8000", help="Server URL")
    args = parser.parse_args()

    success = test_routing_export(args.url)
    sys.exit(0 if success else 1)
update_weights_from_tensor (per-batch CUDA-IPC handoff) for every parameter + found on disk -> repopulate weights + 5. final generate -> should produce a sensible answer again + +The "trainer" runs in this same process: it holds tensors on a free GPU, serialises +them via lightllm.utils.rl.serialization.MultiprocessingSerializer (CUDA IPC handles, not +data), then asks the server to clone them into its weight buffers. No NCCL group +is required, so this is safe to interrupt without leaving the server hung. + +Usage: + python test/test_api/test_rl_endpoints.py \ + --url http://127.0.0.1:8000 \ + --model_dir /nvme/models/Qwen3.5-35B-A3B \ + --tp 4 \ + --server_devices 0,1,2,3 \ + --client_device 4 + +Notes: + - This script must run on the same machine as the server (CUDA IPC). + - --server_devices are nvidia-smi GPU indices for the TP workers. If omitted, + the script infers them from the top --tp memory consumers before release. + - --client_device picks a free CUDA device for the in-process trainer; it is + independent from --server_devices and should not overlap the TP workers. +""" + +import argparse +import json +import os +import subprocess +import sys +import time +from glob import glob +from typing import Dict, List, Tuple + +# Make the repo importable when this script is invoked by path rather than -m. 
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import requests +import torch +from safetensors import safe_open + +from lightllm.utils.rl.serialization import MultiprocessingSerializer +from lightllm.utils.rl.torch_cuda_ipc import monkey_patch_torch_reductions + + +GREEN = "\033[32m" +RED = "\033[31m" +YELLOW = "\033[33m" +RESET = "\033[0m" + + +def banner(msg: str): + print(f"\n{YELLOW}=== {msg} ==={RESET}", flush=True) + + +def ok(msg: str): + print(f" {GREEN}OK{RESET} {msg}", flush=True) + + +def fail(msg: str): + print(f" {RED}FAIL{RESET} {msg}", flush=True) + + +def gpu_mem_used_mib() -> List[int]: + out = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]).decode() + return [int(x.strip()) for x in out.strip().splitlines()] + + +def _select_gpu_mem(mem: List[int], devices: List[int]) -> List[int]: + return [mem[i] for i in devices] + + +def _resolve_server_devices(server_devices: str, tp: int, mem: List[int]) -> List[int]: + if tp <= 0: + raise ValueError(f"--tp must be positive, got {tp}") + if tp > len(mem): + raise ValueError(f"--tp={tp} but nvidia-smi only returned {len(mem)} GPUs") + + value = server_devices.strip() + if value.lower() == "auto": + return sorted(range(len(mem)), key=lambda i: mem[i], reverse=True)[:tp] + + devices = [int(x.strip()) for x in value.split(",") if x.strip()] + if len(devices) != tp: + raise ValueError(f"--server_devices must contain exactly --tp entries; got {devices} for tp={tp}") + if len(set(devices)) != len(devices): + raise ValueError(f"--server_devices contains duplicates: {devices}") + + bad = [i for i in devices if i < 0 or i >= len(mem)] + if bad: + raise ValueError(f"--server_devices contains invalid GPU indices {bad}; nvidia-smi returned {len(mem)} GPUs") + return devices + + +def post(url: str, path: str, payload=None, timeout=600): + r = requests.post(url + path, 
def update_weights_from_disk_via_tensor_api(
    url: str,
    model_dir: str,
    tp: int,
    client_device: int,
    batch_per_request: int = 8,
    flush_cache_at_end: bool = True,
):
    """Stream every tensor found under ``model_dir`` to the server in batches.

    Acts as an in-process "trainer": each safetensors shard is opened, its
    tensors are loaded onto ``cuda:{client_device}`` in groups of
    ``batch_per_request``, serialized with ``MultiprocessingSerializer``, and
    posted to ``/update_weights_from_tensor``. The same serialized blob is sent
    ``tp`` times, once per rank.

    Args:
        url: Base URL of the server.
        model_dir: Directory containing ``*.safetensors`` shards.
        tp: Number of per-rank copies of each payload to send.
        client_device: CUDA device index used to hold the tensors being sent.
        batch_per_request: Number of tensors per request; must be at least 1.
        flush_cache_at_end: When true, the last batch of the last shard is sent
            with ``flush_cache=True``.

    Raises:
        ValueError: If ``batch_per_request`` is smaller than 1.
        RuntimeError: If any update request returns a status other than 200.
    """
    # Validate before touching the GPU or the server. A negative step would make
    # the batching range empty and report success with nothing sent.
    if batch_per_request < 1:
        raise ValueError(f"batch_per_request must be >= 1, got {batch_per_request}")

    banner("update_weights_from_tensor (CUDA IPC)")
    # Patch this process's torch reductions so CUDA tensors can be serialized;
    # per the original author's note the server patches its own copy separately.
    monkey_patch_torch_reductions()
    torch.cuda.set_device(client_device)
    device = f"cuda:{client_device}"

    shards = _list_safetensor_shards(model_dir)
    print(f" found {len(shards)} safetensor shards, batch_per_request={batch_per_request}", flush=True)

    total_params = 0
    total_bytes = 0
    t0 = time.time()
    last_shard_idx = len(shards) - 1
    for shard_idx, shard in enumerate(shards):
        shard_t0 = time.time()
        with safe_open(shard, framework="pt") as f:
            keys = list(f.keys())
            for i in range(0, len(keys), batch_per_request):
                batch_keys = keys[i : i + batch_per_request]
                # Load the batch onto the client GPU. .contiguous() is kept as
                # cheap insurance that each tensor is a whole allocation.
                tensors = [f.get_tensor(k).to(device).contiguous() for k in batch_keys]
                named: List[Tuple[str, torch.Tensor]] = list(zip(batch_keys, tensors))

                # Same payload for every TP rank. NOTE(review): the original
                # comment says the server does the TP sharding itself - that is
                # not visible from this script; confirm against the server.
                blob = MultiprocessingSerializer.serialize(named, output_str=True)
                serialized_per_rank = [blob] * tp

                # Only the final batch asks for a cache flush, so KV entries from
                # the previous weights are dropped once all updates are in.
                is_last = (shard_idx == last_shard_idx) and (i + batch_per_request >= len(keys))
                code, body = _send_update_from_tensor(
                    url,
                    serialized_per_rank,
                    flush_cache=(flush_cache_at_end and is_last),
                )
                if code != 200:
                    fail(f"update batch failed: {code} {body}")
                    raise RuntimeError(f"update batch failed: {code} {body}")
                total_params += len(batch_keys)
                total_bytes += sum(t.numel() * t.element_size() for t in tensors)
                # Drop the only references to this batch before loading the next
                # one, then hand the freed blocks back to the driver.
                del tensors, named
                torch.cuda.empty_cache()

        print(
            f" shard {shard_idx+1}/{len(shards)} done "
            f"(+{len(keys)} tensors, {time.time()-shard_t0:.1f}s, "
            f"running total {total_params} params, {total_bytes/1e9:.1f} GB)",
            flush=True,
        )

    dt = time.time() - t0
    ok(f"streamed {total_params} params, {total_bytes/1e9:.1f} GB in {dt:.1f}s")
args.max_new_tokens) + print(f" prompt : {args.prompt!r}") + print(f" generated: {base_text!r}") + ok("baseline generated") + + # ---------------- stage 2: release ---------------- + banner("release_memory_occupation") + before = gpu_mem_used_mib() + try: + server_devices = _resolve_server_devices(args.server_devices, args.tp, before) + except ValueError as e: + fail(str(e)) + sys.exit(1) + print(f" server GPUs : {server_devices}") + print(f" GPU mem before: {_select_gpu_mem(before, server_devices)}") + code, body = post(args.url, "/release_memory_occupation", {}) + print(f" resp: {code} {body}") + if code != 200: + fail("release failed") + sys.exit(1) + time.sleep(2) + after = gpu_mem_used_mib() + print(f" GPU mem after : {_select_gpu_mem(after, server_devices)}") + drop = sum(_select_gpu_mem(before, server_devices)) - sum(_select_gpu_mem(after, server_devices)) + if drop < 10_000: + fail(f"release did not free much memory (delta={drop} MiB)") + sys.exit(1) + ok(f"release freed ~{drop} MiB on TP GPUs") + + # ---------------- stage 3: resume ---------------- + banner("resume_memory_occupation") + code, body = post(args.url, "/resume_memory_occupation", {}) + print(f" resp: {code} {body}") + if code != 200: + fail("resume failed") + sys.exit(1) + time.sleep(2) + print(f" GPU mem after : {_select_gpu_mem(gpu_mem_used_mib(), server_devices)}") + ok("resume returned success") + + banner("post-resume generate (likely garbage without weight cpu backup)") + text_after_resume = generate(args.url, args.prompt, args.max_new_tokens) + print(f" generated: {text_after_resume!r} garbage_heuristic={looks_garbage(text_after_resume)}") + + if args.skip_update: + ok("done (skipped update_weights stage)") + return + + # ---------------- stage 4: update_weights_from_tensor ---------------- + update_weights_from_disk_via_tensor_api( + url=args.url, + model_dir=args.model_dir, + tp=args.tp, + client_device=args.client_device, + batch_per_request=args.batch_per_request, + 
class TestRoutingCaptureManager:
    """GPU tests for ``RoutingCaptureManager`` capture callbacks and extraction.

    Every test skips itself when CUDA is unavailable.
    """

    def test_capture_and_extract_basic(self):
        """Capture random top-k ids for four layers and read them all back."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=4,
            topk=8,
            num_experts=64,
            kv_cache_size=1024,
            max_capture_tokens=64,
        )
        mem_indexes = torch.arange(100, 110, device="cuda")
        expected = np.zeros((10, 4, 8), dtype=np.uint8)

        make_capture_callback = manager.make_capture_callback_factory(mem_indexes)
        for layer_idx in range(4):
            topk_ids = torch.randint(0, 64, (10, 8), device="cuda")
            make_capture_callback(layer_idx)(topk_ids)
            expected[:, layer_idx, :] = topk_ids.cpu().numpy().astype(np.uint8)

        result = manager.extract_routing_data(mem_indexes)
        assert result.shape == (10, 4, 8)
        assert result.dtype == np.uint8
        np.testing.assert_array_equal(result, expected)

    def test_capture_writes_to_correct_kv_positions(self):
        """Values captured at non-contiguous indexes come back per layer."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=2,
            topk=4,
            num_experts=32,
            kv_cache_size=256,
            max_capture_tokens=16,
        )
        mem_indexes = torch.tensor([10, 50, 200], device="cuda")
        topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda")
        topk_ids_layer1 = topk_ids + 20

        make_capture_callback = manager.make_capture_callback_factory(mem_indexes)
        make_capture_callback(0)(topk_ids)
        make_capture_callback(1)(topk_ids_layer1)

        result = manager.extract_routing_data(mem_indexes)
        np.testing.assert_array_equal(result[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))
        np.testing.assert_array_equal(result[:, 1, :], topk_ids_layer1.cpu().numpy().astype(np.uint8))

    def test_capture_maps_transformer_layer_num_to_routing_slot(self):
        """With ``layer_num_to_moe_index``, layers 3 and 7 land in slots 0 and 1.

        A layer number absent from the mapping yields ``None`` instead of a callback.
        """
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=2,
            topk=2,
            num_experts=32,
            kv_cache_size=256,
            max_capture_tokens=16,
            layer_num_to_moe_index={3: 0, 7: 1},
        )
        mem_indexes = torch.tensor([10, 11], device="cuda")
        ids_layer3 = torch.tensor([[1, 2], [3, 4]], device="cuda")
        ids_layer7 = torch.tensor([[5, 6], [7, 8]], device="cuda")

        make_capture_callback = manager.make_capture_callback_factory(mem_indexes)
        make_capture_callback(3)(ids_layer3)
        make_capture_callback(7)(ids_layer7)
        assert make_capture_callback(4) is None

        result = manager.extract_routing_data(mem_indexes)
        np.testing.assert_array_equal(result[:, 0, :], ids_layer3.cpu().numpy().astype(np.uint8))
        np.testing.assert_array_equal(result[:, 1, :], ids_layer7.cpu().numpy().astype(np.uint8))

    def test_capture_rejects_unexpected_topk_width(self):
        """Ids that are three wide are rejected when the manager has ``topk=2``."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=2,
            num_experts=32,
            kv_cache_size=256,
            max_capture_tokens=16,
        )
        mem_indexes = torch.tensor([10, 11], device="cuda")
        topk_ids_with_shared_expert = torch.tensor([[1, 2, 32], [3, 4, 32]], device="cuda")

        with pytest.raises(AssertionError):
            manager.make_capture_callback_factory(mem_indexes)(0)(topk_ids_with_shared_expert)

    def test_cuda_graph_replay_uses_copied_mem_indexes(self):
        """A captured graph writes to whatever indexes were copied in before replay.

        The callback is captured once against ``graph_infer_state``. After
        ``copy_for_cuda_graph`` and an in-place update of ``topk_ids``, a second
        replay must write the new ids at the new state's indexes.
        """
        _skip_without_cuda()

        class _NoopDecodeState:
            # Stand-in that provides only the hook this test needs to exist.
            def copy_for_decode_cuda_graph(self, other):
                return

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=2,
            num_experts=32,
            kv_cache_size=256,
            max_capture_tokens=16,
        )
        graph_infer_state = InferStateInfo()
        graph_infer_state.decode_att_state = _NoopDecodeState()
        graph_infer_state.mem_index = torch.tensor([10, 11], device="cuda")
        graph_infer_state.make_routing_capture_callback = manager.make_capture_callback_factory(
            graph_infer_state.mem_index
        )
        capture_callback = graph_infer_state.make_routing_capture_callback(0)
        topk_ids = torch.tensor([[1, 2], [3, 4]], device="cuda")

        graph = torch.cuda.CUDAGraph()
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):
            capture_callback(topk_ids)

        # First replay: original ids at the original indexes.
        graph.replay()
        result = manager.extract_routing_data(graph_infer_state.mem_index)
        np.testing.assert_array_equal(result[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))

        new_infer_state = InferStateInfo()
        new_infer_state.decode_att_state = _NoopDecodeState()
        new_infer_state.mem_index = torch.tensor([20, 21], device="cuda")
        new_topk_ids = torch.tensor([[5, 6], [7, 8]], device="cuda")
        # Update the captured inputs in place; the graph itself is not re-captured.
        graph_infer_state.copy_for_cuda_graph(new_infer_state)
        topk_ids.copy_(new_topk_ids)

        graph.replay()

        result = manager.extract_routing_data(new_infer_state.mem_index)
        np.testing.assert_array_equal(result[:, 0, :], new_topk_ids.cpu().numpy().astype(np.uint8))

    def test_microbatch_isolation(self):
        """Two callbacks built from different indexes do not overwrite each other."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=4,
            num_experts=32,
            kv_cache_size=256,
            max_capture_tokens=16,
        )
        mem0 = torch.tensor([10, 11], device="cuda")
        mem1 = torch.tensor([20, 21], device="cuda")
        ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda")
        ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2

        capture0 = manager.make_capture_callback_factory(mem0)(0)
        capture1 = manager.make_capture_callback_factory(mem1)(0)
        capture0(ids_0)
        capture1(ids_1)

        result0 = manager.extract_routing_data(mem0)
        result1 = manager.extract_routing_data(mem1)
        assert result0[0, 0, 0] == 1
        assert result1[0, 0, 0] == 2

    def test_dtype_selection_uint8(self):
        """With 256 experts the manager reports uint8 storage and ``dtype_id`` 1."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=2,
            num_experts=256,
            kv_cache_size=128,
            max_capture_tokens=16,
        )
        assert manager.dtype == torch.uint8
        assert manager.np_dtype == np.uint8
        assert manager.dtype_id == 1

    def test_dtype_selection_int16(self):
        """With 257 experts the manager reports int16 storage and ``dtype_id`` 2."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=2,
            num_experts=257,
            kv_cache_size=128,
            max_capture_tokens=16,
        )
        assert manager.dtype == torch.int16
        assert manager.np_dtype == np.int16
        assert manager.dtype_id == 2

    def test_extract_preserves_uint8_values(self):
        """Ids across the uint8 range, including 0 and 255, round-trip unchanged."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=4,
            num_experts=256,
            kv_cache_size=64,
            max_capture_tokens=16,
        )
        mem_indexes = torch.tensor([0, 1, 2], device="cuda")
        topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 255, 3]], device="cuda")

        manager.make_capture_callback_factory(mem_indexes)(0)(topk_ids)

        result = manager.extract_routing_data(mem_indexes)
        np.testing.assert_array_equal(result[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))

    def test_routing_buffer_and_pointer_shape(self):
        """Check layout and placement of ``routing_buffer`` and ``routing_buffer_ptr``.

        The buffer is expected to be a pinned CPU uint8 tensor of shape
        ``(kv_cache_size, num_moe_layers, topk)``; the pointer a one-element
        uint64 CUDA tensor.
        """
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=48,
            topk=8,
            num_experts=256,
            kv_cache_size=2048,
            max_capture_tokens=256,
        )
        assert manager.routing_buffer.shape == (2048, 48, 8)
        assert manager.routing_buffer.dtype == torch.uint8
        assert manager.routing_buffer.device.type == "cpu"
        assert manager.routing_buffer.is_pinned()
        assert manager.routing_buffer_ptr.shape == (1,)
        assert manager.routing_buffer_ptr.dtype == torch.uint64
        assert manager.routing_buffer_ptr.device.type == "cuda"

    def test_partial_token_capture(self):
        """Positions never written to read back as zeros."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=1,
            topk=2,
            num_experts=32,
            kv_cache_size=128,
            max_capture_tokens=16,
        )
        mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda")
        topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda")

        # Only the first three of the five indexes receive data.
        manager.make_capture_callback_factory(mem_indexes[:3])(0)(topk_ids)

        result_written = manager.extract_routing_data(mem_indexes[:3])
        np.testing.assert_array_equal(result_written[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))

        result_unwritten = manager.extract_routing_data(mem_indexes[3:])
        np.testing.assert_array_equal(result_unwritten[:, 0, :], np.zeros((2, 2), dtype=np.uint8))

    def test_capture_does_not_allocate_capture_buffer(self):
        """The manager must not expose a ``_capture_buffer`` attribute."""
        _skip_without_cuda()

        manager = RoutingCaptureManager(
            num_moe_layers=4,
            topk=8,
            num_experts=64,
            kv_cache_size=1024,
            max_capture_tokens=256,
        )
        assert not hasattr(manager, "_capture_buffer")
@pytest.mark.parametrize("dtype,dtype_id,max_value", [(torch.uint8, 1, 255), (torch.int16, 2, 1024)])
def test_scatter_routing_topk_to_pinned_cpu(dtype, dtype_id, max_value):
    """Scatter random top-k ids for one layer into a pinned CPU buffer.

    After the call, the rows named by ``mem_indexes`` must hold the ids for the
    chosen layer and every other entry of the buffer must still be zero.
    """
    _skip_without_cuda()

    n_tokens, n_layers, k, n_slots = 5, 3, 4, 32
    layer = 1

    ids = torch.randint(0, max_value, (n_tokens, k), dtype=torch.int64, device="cuda")
    slots = torch.tensor([17, 3, 29, 8, 11], dtype=torch.int32, device="cuda")
    host_buffer = torch.zeros((n_slots, n_layers, k), dtype=dtype, device="cpu", pin_memory=True)
    # The kernel receives the host buffer's address through a one-element CUDA tensor.
    host_buffer_addr = torch.tensor([host_buffer.data_ptr()], dtype=torch.uint64, device="cuda")

    scatter_routing_topk_to_cpu(
        topk_ids=ids,
        mem_indexes=slots,
        routing_buffer_ptr=host_buffer_addr,
        moe_layer_index=layer,
        num_moe_layers=n_layers,
        topk=k,
        dtype_id=dtype_id,
    )
    torch.cuda.synchronize()

    # Build the reference on the host with ordinary advanced indexing.
    reference = torch.zeros_like(host_buffer)
    rows = slots.cpu().long()
    reference[rows, layer, :] = ids.cpu().to(dtype)
    assert torch.equal(host_buffer, reference)
def test_scatter_routing_topk_is_cuda_graph_capturable():
    """Capture the scatter call in a CUDA graph and check the replay's output.

    The call is first run once outside capture, then the buffer is zeroed, so
    the values compared at the end can only have been written by ``graph.replay()``.
    """
    _skip_without_cuda()

    num_tokens = 4
    num_moe_layers = 2
    topk = 3
    kv_cache_size = 16

    topk_ids = torch.arange(num_tokens * topk, dtype=torch.int64, device="cuda").view(num_tokens, topk)
    mem_indexes = torch.tensor([2, 4, 6, 8], dtype=torch.int32, device="cuda")
    routing_buffer = torch.zeros(
        (kv_cache_size, num_moe_layers, topk),
        dtype=torch.uint8,
        device="cpu",
        pin_memory=True,
    )
    routing_buffer_ptr = torch.tensor([routing_buffer.data_ptr()], dtype=torch.uint64, device="cuda")

    # Eager run before capture. NOTE(review): presumably this also lets any
    # one-time kernel compilation happen outside the graph - confirm.
    scatter_routing_topk_to_cpu(
        topk_ids=topk_ids,
        mem_indexes=mem_indexes,
        routing_buffer_ptr=routing_buffer_ptr,
        moe_layer_index=0,
        num_moe_layers=num_moe_layers,
        topk=topk,
        dtype_id=1,
    )
    torch.cuda.synchronize()
    # Discard what the eager run wrote.
    routing_buffer.zero_()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        scatter_routing_topk_to_cpu(
            topk_ids=topk_ids,
            mem_indexes=mem_indexes,
            routing_buffer_ptr=routing_buffer_ptr,
            moe_layer_index=0,
            num_moe_layers=num_moe_layers,
            topk=topk,
            dtype_id=1,
        )

    graph.replay()
    torch.cuda.synchronize()

    expected = torch.zeros_like(routing_buffer)
    expected[mem_indexes.cpu(), 0, :] = topk_ids.cpu().to(torch.uint8)
    assert torch.equal(routing_buffer, expected)
def test_case10():
    """
    Test scenario: exercise the flush_cache function.

    Two overlapping sequences are inserted and a prefix lookup is confirmed to
    hit. After ``flush_cache`` the same lookup must miss, the token counters
    must read zero, and the root node must be empty with ``ref_counter == 1``.
    """
    print("\nTest Case 10: Testing flush_cache function\n")

    def _query():
        # A fresh CPU key tensor for each lookup.
        return torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu")

    cache = RadixCache("unique_name", 100, 0)
    for tokens in ([1, 2, 3], [1, 2, 3, 4, 5]):
        cache.insert(torch.tensor(tokens, dtype=torch.int64))

    # Before the flush the three-token prefix is found.
    node, matched_len, _ = cache.match_prefix(_query(), update_refs=True)
    assert node is not None
    assert matched_len == 3

    cache.flush_cache()

    # After the flush the same lookup finds nothing.
    node, matched_len, _ = cache.match_prefix(_query(), update_refs=True)
    assert node is None
    assert matched_len == 0
    assert cache.get_tree_total_tokens_num() == 0
    assert cache.get_refed_tokens_num() == 0

    root = cache.root_node
    assert len(root.children) == 0
    assert root.token_id_key.numel() == 0
    assert root.token_mem_index_value.numel() == 0
    assert root.ref_counter == 1