diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 94f9d4c1a2..5437c24363 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -33,8 +33,13 @@
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
+from lightllm.utils.torch_memory_saver_utils import (
+ TorchMemorySaverWrapper,
+ MemoryTag,
+)
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend
+from . import routing_manager as _routing_mgr
logger = init_logger(__name__)
@@ -90,6 +95,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
+ self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
@@ -103,18 +109,22 @@ def __init__(self, kvargs):
self._verify_params()
self._init_quant()
- self._init_weights()
- self._init_req_manager()
- self._init_mem_manager()
+ enable_weight_cpu_backup = self.args.enable_weight_cpu_backup
+ with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=enable_weight_cpu_backup):
+ self._init_weights()
+ with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
+ self._init_req_manager()
+ self._init_mem_manager()
+
# 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state
# 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值
self.req_manager.mem_manager = self.mem_manager
-
self._check_mem_size()
self._init_infer_layer()
self._init_some_value()
self._init_custom()
- self._load_hf_weights()
+ self._init_routing_capture()
+ self.load_weights(self.weight_dict)
self._init_att_backend()
self._init_att_backend1()
@@ -172,13 +182,14 @@ def _init_weights(self, start_layer_index=0):
]
return
- def _load_hf_weights(self):
+ def load_weights(self, weight_dict: dict):
+ assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None"
load_hf_weights(
- self.data_type,
+ data_type=self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
- weight_dict=self.weight_dict,
+ weight_dict=weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
@@ -289,6 +300,17 @@ def _init_prefill_cuda_graph(self):
def _init_custom(self):
pass
+ def _init_routing_capture(self):
+ if not self.args.enable_return_routed_experts:
+ return
+ if _routing_mgr.g_routing_capture_manager is not None:
+ # MTP draft models share the main model process and KV cache, so they
+ # should reuse the routing capture manager initialized by the main model.
+ logger.info("RoutingCaptureManager already initialized, skip routing capture init.")
+ return
+ _routing_mgr.init_routing_capture(self)
+ return
+
@torch.no_grad()
def forward(self, model_input: ModelInput):
model_input.to_cuda()
@@ -333,6 +355,11 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.mem_index = model_input.mem_indexes
infer_state.microbatch_index = microbatch_index
infer_state.dist_group = dist_group_manager.get_group(microbatch_index)
+ mgr = _routing_mgr.g_routing_capture_manager
+ if mgr is not None:
+ # Build a callback that records each MoE layer's top-k experts into
+ # the routing buffer at this forward's KV-cache positions.
+ infer_state.make_routing_capture_callback = mgr.make_capture_callback_factory(infer_state.mem_index)
# 特殊模型,特殊模式的特定变量初始化操作。
infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens
@@ -1031,6 +1058,7 @@ def _check_max_len_infer(self):
)
logger.error(exception_str)
raise Exception(exception_str)
+ torch.cuda.empty_cache()
return
def autotune_layers(self):
@@ -1165,6 +1193,9 @@ def _init_padded_req(self):
del b_seq_len
del b_ready_cache_len
del model_output
+ del b_mtp_index
+ del b_prefill_start_loc
+ del b_q_seq_len
torch.cuda.empty_cache()
return
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 782150661e..5e8036ee81 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -8,6 +8,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.utils.torch_memory_saver_utils import (
+ TorchMemorySaverWrapper,
+ MemoryTag,
+)
from .infer_struct import InferStateInfo
@@ -26,6 +30,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int =
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+ self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -94,7 +99,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
if param_name not in pure_para_set:
delattr(infer_state, param_name)
- with torch.cuda.graph(graph_obj, pool=self.mempool):
+ with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
graph_obj.replay()
@@ -128,7 +133,7 @@ def _capture_decode_overlap(
if para_name not in pure_para_set1:
delattr(infer_state1, para_name)
- with torch.cuda.graph(graph_obj, pool=self.mempool):
+ with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
graph_obj,
diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
index 711484c835..e09452b5ab 100755
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -100,6 +100,9 @@ def __init__(self):
self.dp_output_split_sizes: List[List[int]] = None
self.dp_input_split_sizes: List[List[int]] = None
+ # Optional hook for recording MoE routing top-k ids during forward.
+ self.make_routing_capture_callback = None
+
def init_some_extra_state(self, model):
if self.is_prefill:
(
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index fca9b80fcf..1e7f94b314 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -124,6 +124,12 @@ def _init_parallel_params(self):
self.expert_idx_to_local_idx = {expert_idx: i for (i, expert_idx) in enumerate(self.local_expert_ids)}
self.rexpert_idx_to_local_idx = {}
+ def _make_routing_capture_callback(self, infer_state):
+ make_routing_capture_callback = getattr(infer_state, "make_routing_capture_callback", None)
+ if make_routing_capture_callback is None:
+ return None
+ return make_routing_capture_callback(self.layer_num_)
+
def experts(
self,
input_tensor: torch.Tensor,
@@ -134,8 +140,9 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ infer_state=None,
) -> torch.Tensor:
- """Backward compatible method that routes to platform-specific implementation."""
+ routing_capture_callback = self._make_routing_capture_callback(infer_state)
return self.fuse_moe_impl(
input_tensor=input_tensor,
router_logits=router_logits,
@@ -149,6 +156,7 @@ def experts(
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
+ routing_capture_callback=routing_capture_callback,
per_expert_scale=self.per_expert_scale,
)
@@ -317,6 +325,7 @@ def _create_weight(self):
device_id=self.device_id_,
num_experts=self.local_n_routed_experts,
)
+ self.w1, self.w3 = w13_param_list
self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0])
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)
@@ -339,6 +348,8 @@ def _get_expert_weight_list(self, weight_pack: WeightPack):
return weight_list
def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]):
+ # for merged weights
+ self._load_merge_weight(weights)
# Load each expert with TP slicing
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
@@ -363,6 +374,7 @@ def _load_expert(
w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}"
w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}"
w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}"
+
row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_weight in weights:
@@ -372,6 +384,19 @@ def _load_expert(
if w2_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx])
+ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]):
+ w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}"
+ w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}"
+ w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}"
+ row_slice_func = self.row_slicer._slice_weight
+ col_slice_func = self.col_slicer._slice_weight
+ if w1_merge_weight in weights:
+ self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1)
+ if w2_merge_weight in weights:
+ self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2)
+ if w3_merge_weight in weights:
+ self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3)
+
def _load_expert_scale(
self,
expert_idx: int,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
index 6ed0cef0b4..2e82571f51 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
@@ -144,10 +144,15 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ infer_state=None,
):
topk_weights, topk_ids = self._router(router_logits, top_k)
+ routing_capture_callback = self._make_routing_capture_callback(infer_state)
+ if routing_capture_callback is not None:
+ routing_capture_callback(topk_ids)
+
w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
index dd6f9a6880..443a271cd4 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
@@ -1,10 +1,10 @@
import torch
from abc import abstractmethod
+from typing import Callable, Optional
from lightllm.common.quantization.quantize_method import (
WeightPack,
QuantizationMethod,
)
-from typing import Optional
from lightllm.utils.dist_utils import (
get_global_rank,
get_global_world_size,
@@ -62,6 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ routing_capture_callback: Optional[Callable[[torch.Tensor], None]] = None,
per_expert_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pass
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index a0d30547a3..b5d231f58e 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -1,5 +1,5 @@
import torch
-from typing import Optional
+from typing import Callable, Optional
from lightllm.common.quantization.no_quant import WeightPack
from lightllm.common.quantization.quantize_method import QuantizationMethod
from .base_impl import FuseMoeBaseImpl
@@ -62,6 +62,7 @@ def _select_experts(
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
+ routed_topk_ids = topk_ids
if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
@@ -83,7 +84,7 @@ def _select_experts(
topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)
- return topk_weights, topk_ids
+ return topk_weights, topk_ids, routed_topk_ids
def _fused_experts(
self,
@@ -128,9 +129,10 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ routing_capture_callback: Optional[Callable[[torch.Tensor], None]] = None,
per_expert_scale: Optional[torch.Tensor] = None,
):
- topk_weights, topk_ids = self._select_experts(
+ selected_experts = self._select_experts(
input_tensor=input_tensor,
router_logits=router_logits,
correction_bias=correction_bias,
@@ -142,6 +144,15 @@ def __call__(
scoring_func=scoring_func,
per_expert_scale=per_expert_scale,
)
+ if len(selected_experts) == 2:
+ topk_weights, topk_ids = selected_experts
+ routed_topk_ids = topk_ids
+ else:
+ topk_weights, topk_ids, routed_topk_ids = selected_experts
+
+ if routing_capture_callback is not None:
+ routing_capture_callback(routed_topk_ids)
+
output = self._fused_experts(
input_tensor=input_tensor,
w13=w13,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
index ddbf98a866..067c1c8ca9 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
@@ -28,6 +28,10 @@ def _get_slice_start_end(self, size: int) -> Tuple[int, int]:
end = start + tp_size
return start, end
+ def _assert_weight_ndim(self, tensor: torch.Tensor) -> None:
+ # 2D: 普通 linear (out, in); 3D: MoE 合并权重 (num_experts, out, in)。
+ assert tensor.dim() in (2, 3), f"expect weight ndim in (2, 3), got shape {tuple(tensor.shape)}"
+
class SliceMixinTpl(SliceMixinBase):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
@@ -46,18 +50,20 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
raise NotImplementedError("slice_weight_zero_point must implement this method")
-# 默认weight 的shape是 outxin,这也是目前最通用的约定。
-# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
+# 默认 weight 的 shape 末两维是 (out, in),普通 linear 是 2D (out, in),
+# MoE 合并权重则是 3D (num_experts, out, in),统一通过 `...` 处理任意前导维。
+# 约定 row-wise 沿着 out 维(倒数第二维)切分,col-wise 沿着 in 维(最后一维)切分。
class RowSliceMixin(SliceMixinTpl):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
super().__init__(tp_rank, tp_world_size, repeat_times)
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
+ self._assert_weight_ndim(weight)
assert (
- weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight.shape[0])
- return weight[start:end, :]
+ weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight.shape[-2])
+ return weight[..., start:end, :]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
assert (
@@ -74,18 +80,20 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
super().__init__(tp_rank, tp_world_size, repeat_times)
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
+ self._assert_weight_ndim(weight_scale)
assert (
- weight_scale.shape[0] % self.tp_world_size_ == 0
- ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_scale.shape[0])
- return weight_scale[start:end]
+ weight_scale.shape[-2] % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_scale.shape[-2])
+ return weight_scale[..., start:end, :]
def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+ self._assert_weight_ndim(weight_zero_point)
assert (
- weight_zero_point.shape[0] % self.tp_world_size_ == 0
- ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_zero_point.shape[0])
- return weight_zero_point[start:end]
+ weight_zero_point.shape[-2] % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_zero_point.shape[-2])
+ return weight_zero_point[..., start:end, :]
class ColSliceMixin(SliceMixinTpl):
@@ -93,11 +101,12 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
super().__init__(tp_rank, tp_world_size, repeat_times)
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
+ self._assert_weight_ndim(weight)
assert (
- weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight.shape[1])
- return weight[:, start:end]
+ weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight.shape[-1])
+ return weight[..., start:end]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias / self.tp_world_size_ * self.repeat_times_
@@ -108,18 +117,20 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
super().__init__(tp_rank, tp_world_size, repeat_times)
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
+ self._assert_weight_ndim(weight_scale)
assert (
- weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_scale.shape[1])
- return weight_scale[:, start:end]
+ weight_scale.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_scale.shape[-1])
+ return weight_scale[..., start:end]
def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+ self._assert_weight_ndim(weight_zero_point)
assert (
- weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_zero_point.shape[1])
- return weight_zero_point[:, start:end]
+ weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_zero_point.shape[-1])
+ return weight_zero_point[..., start:end]
# awq 的量化权重是inxout存储格式,需要定制实现。
diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py
new file mode 100644
index 0000000000..a154415b59
--- /dev/null
+++ b/lightllm/common/basemodel/routing_manager.py
@@ -0,0 +1,276 @@
+import atexit
+import json
+import os
+import torch
+import numpy as np
+from multiprocessing import shared_memory
+from typing import Dict, List, Optional, Tuple
+from lightllm.common.basemodel.triton_kernel.routing_capture import scatter_routing_topk_to_cpu
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.utils.envs_utils import get_unique_server_name
+
+logger = init_logger(__name__)
+
+
+def routing_dtype_id_to_np(dtype_id: int):
+ if dtype_id == 1:
+ return np.uint8
+ elif dtype_id == 2:
+ return np.int16
+ return np.int32
+
+
+def _get_model_text_config(config: dict) -> dict:
+ return config.get("text_config", config)
+
+
+def _get_num_moe_layers_from_config(config: dict) -> int:
+ num_layers = config.get("num_hidden_layers", config.get("n_layer", config.get("num_layers", 0)))
+ num_experts = config.get("n_routed_experts", config.get("num_experts", config.get("num_local_experts", 0)))
+ if num_layers <= 0 or not num_experts:
+ return 0
+
+ if "first_k_dense_replace" in config:
+ first_k_dense_replace = config.get("first_k_dense_replace", 0)
+ moe_layer_freq = config.get("moe_layer_freq", 1)
+ return sum(
+ 1
+ for layer_num in range(num_layers)
+ if layer_num >= first_k_dense_replace and layer_num % moe_layer_freq == 0
+ )
+
+ if "mlp_only_layers" in config or "decoder_sparse_step" in config:
+ mlp_only_layers = set(config.get("mlp_only_layers", []))
+ decoder_sparse_step = config.get("decoder_sparse_step", 1)
+ return sum(
+ 1
+ for layer_num in range(num_layers)
+ if layer_num not in mlp_only_layers and (layer_num + 1) % decoder_sparse_step == 0
+ )
+
+ if config.get("enable_moe_block", False):
+ return num_layers
+
+ return num_layers
+
+
+def get_routing_config_from_model_dir(model_dir: str) -> Optional[Tuple[int, int, int]]:
+ with open(os.path.join(model_dir, "config.json"), "r") as json_file:
+ config = _get_model_text_config(json.load(json_file))
+
+ num_moe_layers = _get_num_moe_layers_from_config(config)
+ topk = config.get("num_experts_per_tok", config.get("top_k_experts", 0))
+ num_experts = config.get("n_routed_experts", config.get("num_experts", config.get("num_local_experts", 0)))
+ if num_moe_layers <= 0 or topk <= 0 or not num_experts:
+ return None
+
+ dtype_id = 1 if num_experts <= 256 else 2
+ return num_moe_layers, topk, dtype_id
+
+
+class RoutingCaptureManager:
+ def __init__(
+ self,
+ num_moe_layers: int,
+ topk: int,
+ num_experts: int,
+ kv_cache_size: int,
+ max_capture_tokens: int,
+ layer_num_to_moe_index: Optional[Dict[int, int]] = None,
+ ):
+ self.num_moe_layers = num_moe_layers
+ self.topk = topk
+ self.num_experts = num_experts
+ self.kv_cache_size = kv_cache_size
+ self.max_capture_tokens = max_capture_tokens
+ self.layer_num_to_moe_index = layer_num_to_moe_index or {i: i for i in range(num_moe_layers)}
+
+ self.dtype = torch.uint8 if num_experts <= 256 else torch.int16
+ dtype_bytes = 1 if self.dtype == torch.uint8 else 2
+
+ # Shape: (kv_cache_size, num_moe_layers, topk). Pinned CPU memory saves GPU memory
+ # while allowing the Triton scatter kernel to write without a synchronous D2H copy.
+ routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
+ self.routing_buffer = torch.zeros(
+ (kv_cache_size, num_moe_layers, topk),
+ dtype=self.dtype,
+ device="cpu",
+ pin_memory=True,
+ )
+ self.routing_buffer_ptr = torch.tensor([self.routing_buffer.data_ptr()], dtype=torch.uint64, device="cuda")
+
+ dtype_name = "uint8" if self.dtype == torch.uint8 else "int16"
+ logger.info(
+ f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
+ f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, "
+ f"dtype={dtype_name}"
+ )
+
+ @property
+ def np_dtype(self):
+ return np.uint8 if self.dtype == torch.uint8 else np.int16
+
+ @property
+ def dtype_id(self) -> int:
+ return 1 if self.dtype == torch.uint8 else 2
+
+ def make_capture_callback_factory(self, mem_indexes: torch.Tensor):
+ if not mem_indexes.is_cuda:
+ mem_indexes = mem_indexes.cuda(non_blocking=True)
+
+ def make_capture_callback(layer_num: int):
+ routing_layer_index = self.layer_num_to_moe_index.get(layer_num)
+ if routing_layer_index is None:
+ return None
+
+ def capture_callback(topk_ids: torch.Tensor) -> None:
+ self.capture(routing_layer_index=routing_layer_index, topk_ids=topk_ids, mem_indexes=mem_indexes)
+
+ return capture_callback
+
+ return make_capture_callback
+
+ def capture(self, routing_layer_index: int, topk_ids: torch.Tensor, mem_indexes: torch.Tensor) -> None:
+ assert topk_ids.dim() == 2
+ assert topk_ids.shape[1] == self.topk
+ assert mem_indexes.shape[0] >= topk_ids.shape[0]
+ scatter_routing_topk_to_cpu(
+ topk_ids=topk_ids,
+ mem_indexes=mem_indexes,
+ routing_buffer_ptr=self.routing_buffer_ptr,
+ moe_layer_index=routing_layer_index,
+ num_moe_layers=self.num_moe_layers,
+ topk=self.topk,
+ dtype_id=self.dtype_id,
+ )
+
+ def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray:
+ torch.cuda.current_stream().synchronize()
+ cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes
+ return self.routing_buffer[cpu_indexes, :, :].numpy()
+
+
+g_routing_capture_manager: Optional[RoutingCaptureManager] = None
+
+
+def create_routing_capture_manager(
+ num_moe_layers: int,
+ topk: int,
+ num_experts: int,
+ kv_cache_size: int,
+ max_capture_tokens: int,
+ layer_num_to_moe_index: Optional[Dict[int, int]] = None,
+) -> None:
+ global g_routing_capture_manager
+ assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
+ g_routing_capture_manager = RoutingCaptureManager(
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ num_experts=num_experts,
+ kv_cache_size=kv_cache_size,
+ max_capture_tokens=max_capture_tokens,
+ layer_num_to_moe_index=layer_num_to_moe_index,
+ )
+
+
+def _get_moe_layer_nums(model) -> List[int]:
+ moe_layer_nums = []
+ for layer_weight in getattr(model, "trans_layers_weight", []):
+ is_moe = getattr(layer_weight, "is_moe", None)
+ if is_moe is None:
+ is_moe = hasattr(layer_weight, "experts")
+ if is_moe:
+ moe_layer_nums.append(layer_weight.layer_num_)
+ return moe_layer_nums
+
+
+def cleanup_routing_shm_pool() -> None:
+ """Unlink all pre-allocated routing SHM segments. Called at server shutdown."""
+ try:
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ args = get_env_start_args()
+ except Exception:
+ return
+
+ service_name = get_unique_server_name()
+
+ for i in range(args.running_max_req_size):
+ name = f"{service_name}_shm_routing_{i}"
+ try:
+ shm = shared_memory.SharedMemory(name=name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+ config_name = f"{service_name}_routing_config"
+ try:
+ shm = shared_memory.SharedMemory(name=config_name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+
+def init_routing_capture(model, num_moe_layers: Optional[int] = None) -> None:
+ moe_layer_nums = _get_moe_layer_nums(model)
+ if num_moe_layers is None:
+ num_moe_layers = len(moe_layer_nums)
+ elif moe_layer_nums:
+ assert num_moe_layers == len(moe_layer_nums)
+ else:
+ moe_layer_nums = list(range(num_moe_layers))
+ layer_num_to_moe_index = {layer_num: moe_index for moe_index, layer_num in enumerate(moe_layer_nums)}
+
+ dp_rank = get_current_rank_in_dp()
+ logger.info(
+ f"init_routing_capture called: num_moe_layers={num_moe_layers}, "
+ f"moe_layer_nums={moe_layer_nums}, dp_rank={dp_rank}"
+ )
+ if dp_rank != 0:
+ logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}")
+ return
+
+ if num_moe_layers == 0:
+ logger.warning(
+ "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled."
+ )
+ return
+
+ num_experts = model.config.get(
+ "n_routed_experts",
+ model.config.get("num_experts", model.config.get("num_local_experts", 0)),
+ )
+ topk = model.config.get("num_experts_per_tok", 0)
+ assert num_experts > 0 and topk > 0
+
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ args = get_env_start_args()
+
+ # Capture buffer must fit the max tokens in any single forward call.
+ # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size.
+ batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192
+ max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size)
+
+ logger.info(
+ f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
+ f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}"
+ )
+
+ create_routing_capture_manager(
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ num_experts=num_experts,
+ kv_cache_size=model.mem_manager.size + 1,
+ max_capture_tokens=max_capture_tokens,
+ layer_num_to_moe_index=layer_num_to_moe_index,
+ )
+
+ logger.info(
+ f"Routing capture config set: num_moe_layers={num_moe_layers}, topk={topk}, "
+ f"dtype_id={g_routing_capture_manager.dtype_id}, max_tokens={args.max_req_total_len}"
+ )
+ atexit.register(cleanup_routing_shm_pool)
diff --git a/lightllm/common/basemodel/triton_kernel/routing_capture.py b/lightllm/common/basemodel/triton_kernel/routing_capture.py
new file mode 100644
index 0000000000..d0fa822058
--- /dev/null
+++ b/lightllm/common/basemodel/triton_kernel/routing_capture.py
@@ -0,0 +1,74 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _scatter_routing_topk_to_cpu(
+ topk_ids,
+ mem_indexes,
+ routing_buffer_ptr,
+ total_size,
+ moe_layer_index: tl.constexpr,
+ layer_topk_size: tl.constexpr,
+ topk: tl.constexpr,
+ dtype_id: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offsets = pid * BLOCK + tl.arange(0, BLOCK)
+ mask = offsets < total_size
+
+ token_offsets = offsets // topk
+ topk_offsets = offsets - token_offsets * topk
+ mem_index = tl.load(mem_indexes + token_offsets, mask=mask, other=-1).to(tl.int64)
+ data = tl.load(topk_ids + offsets, mask=mask, other=0)
+
+ dst_offsets = mem_index * layer_topk_size + moe_layer_index * topk + topk_offsets
+ if dtype_id == 1:
+ dst_ptr = tl.load(routing_buffer_ptr).to(tl.pointer_type(tl.uint8))
+ tl.store(dst_ptr + dst_offsets, data.to(tl.uint8), mask=mask)
+ else:
+ dst_ptr = tl.load(routing_buffer_ptr).to(tl.pointer_type(tl.int16))
+ tl.store(dst_ptr + dst_offsets, data.to(tl.int16), mask=mask)
+
+
+def scatter_routing_topk_to_cpu(
+ topk_ids: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ routing_buffer_ptr: torch.Tensor,
+ moe_layer_index: int,
+ num_moe_layers: int,
+ topk: int,
+ dtype_id: int,
+):
+ assert topk_ids.is_cuda
+ assert mem_indexes.is_cuda
+ assert mem_indexes.is_contiguous()
+ assert routing_buffer_ptr.is_cuda
+ assert routing_buffer_ptr.dtype == torch.uint64
+ assert routing_buffer_ptr.numel() == 1
+ assert topk_ids.dim() == 2
+ assert topk_ids.shape[1] == topk
+ assert topk_ids.is_contiguous()
+ assert 0 <= moe_layer_index < num_moe_layers
+
+ num_tokens = topk_ids.shape[0]
+ layer_topk_size = num_moe_layers * topk
+ total_size = num_tokens * topk
+ if total_size == 0:
+ return
+
+ BLOCK = 1024
+ grid = (triton.cdiv(total_size, BLOCK),)
+ _scatter_routing_topk_to_cpu[grid](
+ topk_ids=topk_ids,
+ mem_indexes=mem_indexes,
+ routing_buffer_ptr=routing_buffer_ptr,
+ total_size=total_size,
+ moe_layer_index=moe_layer_index,
+ layer_topk_size=layer_topk_size,
+ topk=topk,
+ dtype_id=dtype_id,
+ BLOCK=BLOCK,
+ )
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..c75c871c72
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..14026090e6
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "800": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json
new file mode 100644
index 0000000000..939c939523
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 16
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8448": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..13ba4ba8e5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "67584": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ee316f610b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..e027701092
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ddda23d257
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..560ca6c09d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..0713de7996
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..e950ff0954
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..7f479b8382
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "8448": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..b3051c6584
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "8448": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..fdb3212216
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..a94e669353
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..441421fd5d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "67584": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..864d1d3f18
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "2048": {
+ "BLOCK_SIZE": 4096,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..bcf56e01f7
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_stages": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ba1dc8a75d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "3072": {
+ "BLOCK_SIZE": 2048,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..6f109e1c6e
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "5120": {
+ "BLOCK_SIZE": 32768,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..198a196dfb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "2048": {
+ "BLOCK_SIZE": 1024,
+ "num_stages": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..537c7a90eb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "256": {
+ "BLOCK_SIZE": 512,
+ "num_stages": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..9a6dcb6fbf
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "4096": {
+ "BLOCK_SIZE": 1024,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..df501847ec
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "5120": {
+ "BLOCK_SIZE": 1024,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py
index c62a2572ff..4cc6453d12 100644
--- a/lightllm/common/triton_utils/autotuner.py
+++ b/lightllm/common/triton_utils/autotuner.py
@@ -11,7 +11,7 @@
from frozendict import frozendict
from lightllm.utils.device_utils import get_current_device_name
from lightllm.utils.log_utils import init_logger
-from typing import Callable, Optional, Union, List
+from typing import Callable, List
from lightllm.utils.envs_utils import get_triton_autotune_level
from lightllm.common.kernel_config import KernelConfigs
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
@@ -106,14 +106,6 @@ def __init__(
self.configs_gen_func = configs_gen_func
self.kernel_name = kernel_name
- self.cache_dir = os.path.join(
- Path(__file__).parent,
- "autotune_kernel_configs",
- get_triton_version(),
- get_current_device_name(),
- self.kernel_name,
- )
- os.makedirs(self.cache_dir, exist_ok=True)
self.fn = fn
self.static_key_func = static_key_func
self.run_key_func = run_key_func
@@ -209,6 +201,25 @@ def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
+ @property
+ def cache_dir(self) -> str:
+ if not hasattr(self, "_cache_dir"):
+ device_name = get_current_device_name()
+ if device_name is None:
+ raise RuntimeError(
+ f"Autotuner for kernel {self.kernel_name} requires a visible CUDA/MUSA device "
+ f"to resolve its cache directory, but torch.cuda.is_available() is False."
+ )
+ self._cache_dir = os.path.join(
+ Path(__file__).parent,
+ "autotune_kernel_configs",
+ get_triton_version(),
+ device_name,
+ self.kernel_name,
+ )
+ os.makedirs(self._cache_dir, exist_ok=True)
+ return self._cache_dir
+
def _try_load_cache(self, static_key):
if static_key in self.cached_configs:
return False
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index be819c94a0..dae79cc8a6 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -232,6 +232,7 @@ def _moe_ffn_tp(
use_grouped_topk=self.n_group,
topk_group=self.topk_group,
num_expert_group=self.n_group,
+ infer_state=infer_state,
)
if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -259,6 +260,7 @@ def _moe_ffn_edp(
topk_group=self.topk_group,
num_expert_group=self.n_group,
is_prefill=infer_state.is_prefill,
+ infer_state=infer_state,
)
if self.n_shared_experts is not None:
diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py
index 015b526fbc..2f0c01dbf6 100644
--- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py
@@ -300,6 +300,7 @@ def _ffn_moe(self, input, router_logits, infer_state: InferStateInfo, layer_weig
topk_group=None,
num_expert_group=None,
is_prefill=infer_state.is_prefill,
+ infer_state=infer_state,
)
moe_out = self._tpsp_reduce(input=moe_out, infer_state=infer_state)
return moe_out
diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
index b27ea8fd2d..490d2dc4c5 100644
--- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
@@ -52,6 +52,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ infer_state=infer_state,
)
hidden_states = hidden_states.view(num_tokens, hidden_dim)
return self._tpsp_reduce(input=hidden_states, infer_state=infer_state)
diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py
deleted file mode 100644
index b0e27ac1de..0000000000
--- a/lightllm/models/mixtral/layer_infer/_custom_ops.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import functools
-import json
-import os
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-import triton
-import triton.language as tl
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
-
-# Pytorch version
-# Triton version in progress
-def topk_softmax(
- topk_weights,
- topk_ids,
- token_expert_indicies,
- gating_output,
- topk=2,
-):
- scores = torch.softmax(gating_output, dim=-1)
- topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
- return topk_weights, topk_ids
-
-
-def fused_topk(
- hidden_states: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- renormalize: bool,
- alloc_tensor_func=torch.empty,
-):
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
- M, _ = hidden_states.shape
-
- topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
- topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
- token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
- topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk)
- del token_expert_indicies # Not used. Will be used in the future.
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- return topk_weights, topk_ids
diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
index 0cf651598a..8134dc266d 100644
--- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
@@ -1,9 +1,6 @@
-import os
import torch
-import torch.nn.functional as F
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk
from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight
@@ -21,25 +18,14 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor
num_tokens, hidden_dim = hidden_states.shape
router_logits = layer_weight.moe_gate.mm(hidden_states)
- topk_weights, topk_ids = fused_topk(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=self.num_experts_per_tok,
+ layer_weight.experts.experts(
+ hidden_states,
+ router_logits=router_logits,
+ top_k=self.num_experts_per_tok,
renormalize=self.renormalize,
- alloc_tensor_func=self.alloc_tensor,
+ use_grouped_topk=False,
+ topk_group=None,
+ num_expert_group=None,
+ infer_state=infer_state,
)
- from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
-
- ffn2_out = fused_experts_impl(
- hidden_states=hidden_states,
- w1=layer_weight.experts.w1[0],
- w2=layer_weight.experts.w2[0],
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- use_fp8_w8a8=False,
- w1_scale=None,
- w2_scale=None,
- alloc_tensor_func=self.alloc_tensor,
- )
- return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state)
+ return hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py
index 237c4ad897..c94135573b 100644
--- a/lightllm/models/qwen2_vl/model.py
+++ b/lightllm/models/qwen2_vl/model.py
@@ -12,6 +12,7 @@
from .vision_process import smart_resize
from lightllm.models.qwen2.model import Qwen2TpPartModel
import os
+from typing import Union, List
# Warp of the origal tokenizer
class QWen2VLTokenizer(BaseMultiModalTokenizer):
@@ -52,9 +53,13 @@ def get_image_token_length(self, img: ImageItem):
def get_audio_token_length(self, audio: AudioItem):
raise NotImplementedError
- def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
-
- origin_ids = self.tokenizer.encode(prompt)
+ def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs):
+ if isinstance(prompt, str):
+ origin_ids = self.tokenizer.encode(prompt)
+ elif isinstance(prompt, list):
+ origin_ids = prompt
+ else:
+ raise ValueError(f"Unsupported prompt type: {type(prompt)}")
#
->
origin_ids = [token for token in origin_ids if token != self.image_token_id]
diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 8879aa2d27..7edfd5a6f9 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -86,6 +86,7 @@ def _moe_ffn_tp(
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ infer_state=infer_state,
)
return hidden_states.view(num_tokens, hidden_dim)
@@ -105,6 +106,7 @@ def _moe_ffn_edp(
topk_group=None,
num_expert_group=None,
is_prefill=infer_state.is_prefill,
+ infer_state=infer_state,
)
ep_output = ep_output.view(token_num, hidden_dim)
diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index e4d80e6ff9..60bf0e6b76 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -135,6 +135,7 @@ def _moe_ffn_tp(
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ infer_state=infer_state,
)
hidden_states = hidden_states.view(num_tokens, hidden_dim)
hidden_states.add_(shared_expert_out)
@@ -156,6 +157,7 @@ def _moe_ffn_edp(
topk_group=None,
num_expert_group=None,
is_prefill=infer_state.is_prefill,
+ infer_state=infer_state,
)
ep_output = ep_output.view(token_num, hidden_dim)
ep_output.add_(shared_expert_out)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 1bdf8f3427..1ef6375c3e 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -1,8 +1,7 @@
import argparse
-def make_argument_parser() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser()
+def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
@@ -80,6 +79,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
when a llm infer node start to set this params, the visual infer module will start a
proxy module use config server to find remote vit infer nodes to infer img""",
)
+ parser.add_argument(
+ "--rl_rpyc_port",
+ type=int,
+ default=None,
+ help="The router RL control RPyC port. If unset, LightLLM will allocate one automatically.",
+ )
parser.add_argument(
"--pd_kv_page_num",
type=int,
@@ -256,6 +261,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch"
)
+ parser.add_argument(
+ "--lightllm_instance_id",
+ type=int,
+ default=0,
+ help="Instance ID (0~7) for multi-instance port isolation. Each ID maps to a dedicated port range.",
+ )
parser.add_argument(
"--use_config_server_to_init_nccl",
action="store_true",
@@ -769,6 +780,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
)
+ parser.add_argument(
+ "--enable_torch_memory_saver",
+ action="store_true",
+ help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""",
+ )
+ parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""")
parser.add_argument(
"--disk_cache_dir",
type=str,
@@ -842,6 +859,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
If the op is not implemented for the platform and the hardware support triton,
it will use triton implementation.""",
)
+ parser.add_argument(
+ "--enable_return_routed_experts",
+ action="store_true",
+ default=False,
+ help="Enable returning routed expert indices for MoE models (R3 feature).",
+ )
parser.add_argument(
"--enable_profiling",
type=str,
@@ -858,3 +881,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
A NVTX range named 'LIGHTLLM_PROFILE' will be added within the profiling range.""",
)
return parser
+
+
+def make_argument_parser() -> argparse.ArgumentParser:
+ return add_cli_args(argparse.ArgumentParser())
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 270e2a8cfd..3393fec4a9 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -34,7 +34,7 @@
import uuid
from PIL import Image
import multiprocessing as mp
-from typing import AsyncGenerator, Union
+from typing import Any, AsyncGenerator, Union
from typing import Callable
from lightllm.server import TokenLoad
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -50,6 +50,7 @@
from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq
from dataclasses import dataclass
from .api_openai import chat_completions_impl, completions_impl
@@ -61,6 +62,16 @@
ModelCard,
ModelListResponse,
)
+from .io_struct import (
+ AbortReq,
+ FlushCacheReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromTensorReq,
+ UpdateWeightsFromIPCReq,
+ GeneralModelToHttpRpcRsp,
+)
from .build_prompt import build_prompt, init_tokenizer
logger = init_logger(__name__)
@@ -188,6 +199,22 @@ def get_model_name():
return {"model_name": g_objs.args.model_name}
+@app.get("/get_server_info")
+@app.post("/get_server_info")
+def get_server_info():
+ # 将 StartArgs 转换为字典格式
+ from dataclasses import asdict
+
+ server_info: dict[str, Any] = asdict(g_objs.args)
+ return {**server_info}
+
+
+@app.get("/get_weight_version")
+@app.post("/get_weight_version")
+def get_weight_version():
+ return {"weight_version": g_objs.args.weight_version}
+
+
@app.get("/healthz", summary="Check server health")
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
@@ -411,6 +438,91 @@ async def metrics() -> Response:
return response
+@app.post("/abort_request")
+async def abort_request(request: AbortReq, raw_request: Request):
+ """Abort a request."""
+ try:
+ success, msg = await g_objs.httpserver_manager.abort_request(request)
+ if not success:
+ return create_error_response(HTTPStatus.REQUEST_TIMEOUT, msg, err_type="AbortRequestTimeout")
+ return Response(status_code=200)
+ except Exception as e:
+ logger.error("abort_request error occurred: %s", str(e), exc_info=True)
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+async def handle_request_common(request_obj, handler):
+ try:
+ ret: GeneralModelToHttpRpcRsp = await handler(request_obj)
+ if ret.success:
+ return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200)
+ else:
+ return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg)
+ except Exception as e:
+ logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True)
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request):
+ """Init weights update group."""
+ return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group)
+
+
+@app.post("/destroy_weights_update_group")
+async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request):
+ """Destroy weights update group."""
+ return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request):
+ """Update model parameter from distributed online."""
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed)
+
+
+@app.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request):
+ """Update model parameter from distributed online."""
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor)
+
+
+@app.post("/update_weights_from_ipc")
+async def update_weights_from_ipc(request: UpdateWeightsFromIPCReq, raw_request: Request):
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_ipc)
+
+
+@app.post("/flush_cache")
+@app.get("/flush_cache")
+async def flush_cache():
+ """Flush the radix cache."""
+ return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache)
+
+
+@app.post("/pause_generation")
+async def pause_generation():
+ await g_objs.httpserver_manager.pause_generation()
+ return Response(content="Generation paused successfully.", status_code=200)
+
+
+@app.post("/continue_generation")
+async def continue_generation():
+ await g_objs.httpserver_manager.continue_generation()
+ return Response(content="Generation continued successfully.", status_code=200)
+
+
+@app.get("/release_memory_occupation")
+@app.post("/release_memory_occupation")
+async def release_memory_occupation(request: ReleaseMemoryReq):
+ return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation)
+
+
+@app.get("/resume_memory_occupation")
+@app.post("/resume_memory_occupation")
+async def resume_memory_occupation(request: ResumeMemoryReq):
+ return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation)
+
+
@app.websocket("/pd_register")
async def register_and_keep_alive(websocket: WebSocket):
await websocket.accept()
diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py
index 39a5808aab..28d57ccdc4 100644
--- a/lightllm/server/api_lightllm.py
+++ b/lightllm/server/api_lightllm.py
@@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
prompt = request_dict.pop("inputs")
sample_params_dict = request_dict["parameters"]
return_details = sample_params_dict.pop("return_details", False)
+ return_routed_experts = sample_params_dict.pop(
+ "return_routed_experts", httpserver_manager.args.enable_return_routed_experts
+ )
sampling_params = SamplingParams()
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
sampling_params.verify()
@@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
prompt_token_ids = None
is_first_metadata = True
input_usage = None
+ routed_experts_data = None
async for sub_req_id, request_output, metadata, finish_status in results_generator:
# when set "--return_all_prompt_logprobs", the first token metadata will contains
# prompt_logprobs and prompt_token_ids
@@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
if finish_status.is_finished():
finish_reason_dict[sub_req_id] = finish_status
+ if "routed_experts" in metadata:
+ routed_experts_data = metadata["routed_experts"]
n = sampling_params.n
sub_ids = list(final_output_dict.keys())[:n]
final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids]
@@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
ret["prompt_logprobs"] = prompt_logprobs
if input_usage is not None:
ret["input_usage"] = input_usage
+ if return_routed_experts and routed_experts_data is not None:
+ ret["routed_experts"] = routed_experts_data
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))
@@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer
prompt = request_dict.pop("inputs")
sample_params_dict = request_dict["parameters"]
_ = sample_params_dict.pop("return_details", False)
+ _ = sample_params_dict.pop("return_routed_experts", None)
sampling_params = SamplingParams()
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
sampling_params.verify()
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index 6e04d5d47e..5306ecb698 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -1,11 +1,22 @@
import torch
-from .api_cli import make_argument_parser
+from .api_cli import add_cli_args
+from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
-if __name__ == "__main__":
- torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
- parser = make_argument_parser()
- args = parser.parse_args()
- from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start
+
+def launch_server(args: StartArgs):
+ from .api_start import pd_master_start, normal_or_p_d_start, config_server_start, visual_only_start
+
+ try:
+ # this code will not be ok for settings to fork to subprocess
+ torch.multiprocessing.set_start_method("spawn")
+ except RuntimeError as e:
+ logger.warning(f"Failed to set start method: {e}")
+ except Exception as e:
+ logger.error(f"Failed to set start method: {e}")
+ raise e
if args.run_mode == "pd_master":
pd_master_start(args)
@@ -15,3 +26,13 @@
visual_only_start(args)
else:
normal_or_p_d_start(args)
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser()
+ add_cli_args(parser)
+ args = parser.parse_args()
+
+ launch_server(StartArgs(**vars(args)))
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 3cf431d650..42d32c7ec3 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -1,3 +1,4 @@
+import multiprocessing as mp
import os
import sys
import time
@@ -18,6 +19,7 @@
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.redis_utils import start_redis_service
from lightllm.utils.shm_size_check import check_recommended_shm_size
+from lightllm.server.core.objs.start_args_type import StartArgs
from lightllm.utils.config_utils import (
has_audio_module,
has_vision_module,
@@ -60,9 +62,31 @@ def signal_handler(sig, frame):
process_manager.terminate_all_processes()
logger.info("All processes have been terminated gracefully.")
sys.exit(0)
+ elif sig == signal.SIGHUP:
+ logger.info("Received SIGHUP (terminal closed), shutting down gracefully...")
+ if http_server_process and http_server_process.poll() is None:
+ http_server_process.send_signal(signal.SIGTERM)
+
+ start_time = time.time()
+ while (time.time() - start_time) < 60:
+ if not is_process_active(http_server_process.pid):
+ logger.info("httpserver exit")
+ break
+ time.sleep(1)
+
+ if time.time() - start_time < 60:
+ logger.info("HTTP server has exited gracefully")
+ else:
+ logger.warning("HTTP server did not exit in time, killing it...")
+ kill_recursive(http_server_process)
+
+ process_manager.terminate_all_processes()
+ logger.info("All processes have been terminated gracefully due to terminal closure.")
+ sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGHUP, signal_handler)
logger.info(f"start process pid {os.getpid()}")
if http_server_process:
@@ -70,10 +94,13 @@ def signal_handler(sig, frame):
return
-def normal_or_p_d_start(args):
- from lightllm.server.core.objs.start_args_type import StartArgs
+def _set_envs_and_config(args: StartArgs):
+ mp.set_start_method("spawn", force=True)
- args: StartArgs = args
+
+def _launch_subprocesses(args: StartArgs):
+
+ _set_envs_and_config(args)
auto_set_max_req_total_len(args)
set_unique_server_name(args)
@@ -143,12 +170,6 @@ def normal_or_p_d_start(args):
check_recommended_shm_size(args)
assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
- # 确保单机上多实列不冲突
- if args.zmq_mode == "ipc:///tmp/":
- zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_"
- args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功
- args.zmq_mode = zmq_mode
- logger.info(f"zmq mode head: {args.zmq_mode}")
logger.info(f"use tgi api: {args.use_tgi_api}")
@@ -210,12 +231,16 @@ def normal_or_p_d_start(args):
# mtp params check
if args.mtp_mode is not None:
- assert args.mtp_draft_model_dir is not None
+ if args.mtp_draft_model_dir is None:
+ args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step
assert args.mtp_step > 0
else:
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0
+ # automatically set visual_dp based on visual_tp and tp
+ if args.visual_tp < args.tp and args.tp % args.visual_tp == 0:
+ args.visual_dp = args.tp // args.visual_tp
if args.afs_image_embed_dir is not None:
os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(args.afs_image_embed_dir, 0o777)
@@ -336,8 +361,12 @@ def normal_or_p_d_start(args):
assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
already_uesd_ports = [args.port]
- if args.nccl_port is not None:
+ # nccl_port 只在 rank 0 上 bind(TCPStore listener),其他 rank 是 connect,
+ # 不应该把它加入端口锁定列表,否则单机多节点 tp 测试会冲突。
+ if args.nccl_port is not None and args.node_rank == 0:
already_uesd_ports.append(args.nccl_port)
+ if args.rl_rpyc_port is not None:
+ already_uesd_ports.append(args.rl_rpyc_port)
if args.visual_nccl_ports is not None:
already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp])
if not args.disable_audio and args.audio_nccl_ports is not None:
@@ -350,9 +379,12 @@ def normal_or_p_d_start(args):
node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
- num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp,
+ num=11 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp,
+ instance_id=args.lightllm_instance_id,
used_ports=already_uesd_ports,
)
+ auto_ports_locker = PortLocker(can_use_ports)
+ auto_ports_locker.lock_port()
logger.info(f"alloced ports: {can_use_ports}")
(
nccl_port,
@@ -365,8 +397,9 @@ def normal_or_p_d_start(args):
cache_port,
metric_port,
multi_level_kv_cache_port,
- ) = can_use_ports[0:10]
- can_use_ports = can_use_ports[10:]
+ rl_rpyc_port,
+ ) = can_use_ports[0:11]
+ can_use_ports = can_use_ports[11:]
if args.visual_nccl_ports is None:
args.visual_nccl_ports = can_use_ports[: args.visual_dp]
@@ -383,6 +416,18 @@ def normal_or_p_d_start(args):
# 将申请好的端口放入args参数中
if args.nccl_port is None:
args.nccl_port = nccl_port
+ if args.rl_rpyc_port is None:
+ args.rl_rpyc_port = rl_rpyc_port
+
+ set_unique_server_name(args)
+
+ # 确保单机上多实列不冲突
+ if args.zmq_mode == "ipc:///tmp/":
+ zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_"
+ args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功
+ args.zmq_mode = zmq_mode
+ logger.info(f"zmq mode head: {args.zmq_mode}")
+
args.router_port = router_port
args.router_profiler_port = router_profiler_port
args.detokenization_port = detokenization_port
@@ -415,6 +460,7 @@ def normal_or_p_d_start(args):
logger.info(f"all start args:{args}")
ports_locker.release_port()
+ auto_ports_locker.release_port()
if args.enable_multimodal:
process_manager.start_submodule_processes(
@@ -486,6 +532,13 @@ def normal_or_p_d_start(args):
],
)
+ return process_manager
+
+
+def normal_or_p_d_start(args: StartArgs):
+
+ process_manager = _launch_subprocesses(args)
+
# 启动 Hypercorn
command = [
"hypercorn",
@@ -521,7 +574,7 @@ def normal_or_p_d_start(args):
return
-def pd_master_start(args):
+def pd_master_start(args: StartArgs):
set_unique_server_name(args)
if args.run_mode != "pd_master":
return
@@ -540,16 +593,16 @@ def pd_master_start(args):
logger.info(f"all start args:{args}")
can_use_ports = alloc_can_use_network_port(
- num=1,
- used_ports=[
- args.port,
- ],
+ num=1, used_ports=[args.nccl_port, args.port], instance_id=args.lightllm_instance_id
)
+ auto_ports_locker = PortLocker(can_use_ports)
+ auto_ports_locker.lock_port()
metric_port = can_use_ports[0]
args.metric_port = metric_port
set_env_start_args(args)
+ auto_ports_locker.release_port()
process_manager.start_submodule_processes(
start_funcs=[
@@ -599,7 +652,10 @@ def visual_only_start(args):
can_use_ports = alloc_can_use_network_port(
num=5 + args.visual_dp * args.visual_tp + args.visual_dp,
used_ports=already_uesd_ports,
+ instance_id=args.lightllm_instance_id,
)
+ auto_ports_locker = PortLocker(can_use_ports)
+ auto_ports_locker.lock_port()
if args.visual_gpu_ids is None:
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
@@ -620,6 +676,7 @@ def visual_only_start(args):
logger.info(f"all start args:{args}")
set_env_start_args(args)
+ auto_ports_locker.release_port()
from .visualserver.visual_only_manager import start_visual_process
diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index cbc63c898d..2514d9dacb 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -111,13 +111,18 @@ def __init__(
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
- cls._do_sample = generation_cfg.get("do_sample", False)
- cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
- cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
- cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
- cls._temperature = generation_cfg.get("temperature", 1.0)
- cls._top_p = generation_cfg.get("top_p", 1.0)
- cls._top_k = generation_cfg.get("top_k", -1)
+
+ def _cfg(key, default):
+ v = generation_cfg.get(key)
+ return v if v is not None else default
+
+ cls._do_sample = _cfg("do_sample", False)
+ cls._presence_penalty = _cfg("presence_penalty", 0.0)
+ cls._frequency_penalty = _cfg("frequency_penalty", 0.0)
+ cls._repetition_penalty = _cfg("repetition_penalty", 1.0)
+ cls._temperature = _cfg("temperature", 1.0)
+ cls._top_p = _cfg("top_p", 1.0)
+ cls._top_k = _cfg("top_k", -1)
cls._stop_sequences = generation_cfg.get("stop", None)
except:
pass
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 7f2b697091..90fede4cf4 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -3,6 +3,7 @@
import ctypes
import numpy as np
import time
+from multiprocessing import shared_memory
from .sampling_params import SamplingParams
from .out_token_circlequeue import CircularQueue
from .shm_array import ShmArray
@@ -14,6 +15,7 @@
from lightllm.utils.kv_cache_utils import compute_token_list_hash
from typing import List, Any, Union
from lightllm.utils.log_utils import init_logger
+from lightllm.utils.shm_utils import create_or_link_shm
logger = init_logger(__name__)
@@ -25,19 +27,20 @@ class FinishStatus(ctypes.Structure):
NO_FINISH = 0
FINISHED_STOP = 1
FINISHED_LENGTH = 2
+ FINISHED_ABORTED = 3
def __init__(self, init_state=NO_FINISH):
self.status = init_state
def set_status(self, new_status):
- assert 0 <= new_status <= 2
+ assert 0 <= new_status <= 3
self.status = new_status
def get_status(self):
return self.status
def is_finished(self):
- return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH
+ return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED
def is_stopped(self):
return self.status == self.FINISHED_STOP
@@ -50,6 +53,8 @@ def get_finish_reason(self):
return "stop"
elif self.status == self.FINISHED_LENGTH:
return "length"
+ elif self.status == self.FINISHED_ABORTED:
+ return "abort"
return None
@@ -277,6 +282,43 @@ def link_logprobs_shm_array(self):
self.shm_logprobs.link_shm()
return
+ def get_routing_data_shm_name(self):
+ service_uni_name = get_unique_server_name()
+ return f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+
+ def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.uint8):
+ """Create routing SHM at actual size."""
+ self.shm_routing_data = ShmArray(
+ self.get_routing_data_shm_name(), (num_tokens, num_moe_layers, topk), dtype=np_dtype
+ )
+ self.shm_routing_data.create_shm()
+ return
+
+ def link_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.uint8):
+ """Link routing SHM at actual size."""
+ if num_moe_layers <= 0 or num_tokens <= 0 or topk <= 0:
+ return None
+ shm_routing_data = ShmArray(
+ self.get_routing_data_shm_name(), (num_tokens, num_moe_layers, topk), dtype=np_dtype
+ )
+ shm_routing_data.link_shm()
+ self.shm_routing_data = shm_routing_data
+ return self.shm_routing_data.arr
+
+ def close_routing_data_shm_array(self):
+ """Close and unlink routing SHM (on-demand, no longer pooled)."""
+ if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None:
+ self.shm_routing_data.close_shm()
+ self.shm_routing_data = None
+ else:
+ try:
+ shm = shared_memory.SharedMemory(name=self.get_routing_data_shm_name())
+ shm.close()
+ shm.unlink()
+ except FileNotFoundError:
+ pass
+ return
+
def get_prompt_ids(self):
return self.shm_prompt_ids.arr[: self.input_len].tolist()
@@ -297,9 +339,8 @@ def can_release(self):
ref_count_ok = self.ref_count == 1
can_released_mark = self.can_released_mark
- if self.is_aborted and can_released_mark and ref_count_ok:
- return True
-
+ # if self.is_aborted and can_released_mark and ref_count_ok:
+ # return True
ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched
if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index c39559f5f6..9cb52f02b7 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -396,15 +396,18 @@ def init(self, tokenizer, **kwargs):
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
- cls._do_sample = generation_cfg.get("do_sample", False)
- cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
- cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
- cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
- if cls._repetition_penalty is None:
- cls._repetition_penalty = 1.0
- cls._temperature = generation_cfg.get("temperature", 1.0)
- cls._top_p = generation_cfg.get("top_p", 1.0)
- cls._top_k = generation_cfg.get("top_k", -1)
+
+ def _cfg(key, default):
+ v = generation_cfg.get(key)
+ return v if v is not None else default
+
+ cls._do_sample = _cfg("do_sample", False)
+ cls._presence_penalty = _cfg("presence_penalty", 0.0)
+ cls._frequency_penalty = _cfg("frequency_penalty", 0.0)
+ cls._repetition_penalty = _cfg("repetition_penalty", 1.0)
+ cls._temperature = _cfg("temperature", 1.0)
+ cls._top_p = _cfg("top_p", 1.0)
+ cls._top_k = _cfg("top_k", -1)
except:
pass
diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py
index c5ad512c6b..74d64b6c5e 100644
--- a/lightllm/server/core/objs/shm_array.py
+++ b/lightllm/server/core/objs/shm_array.py
@@ -26,6 +26,13 @@ def link_shm(self):
self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
return
+ def detach_shm(self):
+ """Close handle without unlinking (SHM persists for reuse)."""
+ if self.shm is not None:
+ self.shm.close()
+ self.shm = None
+ self.arr = None
+
def close_shm(self):
if self.shm is not None:
self.shm.close()
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 40c8028158..f6eacbfc8b 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
-# 只是为了更好的编程提示
+# 服务启动参数
@dataclass
@@ -10,31 +10,50 @@ class StartArgs:
default="normal",
metadata={"choices": ["normal", "pd_master", "prefill", "decode", "config_server", "visual_only"]},
)
+ performance_mode: str = field(default=None, metadata={"choices": ["personal"]})
host: str = field(default="127.0.0.1")
port: int = field(default=8000)
+ httpserver_workers: int = field(default=1)
zmq_mode: str = field(
default="ipc:///tmp/",
metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"},
)
- pd_master_ip: str = field(default="127.0.0.1")
+ pd_master_ip: str = field(default="0.0.0.0")
pd_master_port: int = field(default=1212)
config_server_host: str = field(default=None)
config_server_port: int = field(default=None)
config_server_visual_redis_port: int = field(default=None)
afs_image_embed_dir: str = field(default=None)
afs_embed_capacity: int = field(default=250000)
- select_p_d_node_strategy: str = field(default=None)
+ rl_rpyc_port: int = field(default=None)
+ select_p_d_node_strategy: str = field(
+ default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]}
+ )
model_name: str = field(default="default_model_name")
+ model_owner: Optional[str] = field(default=None)
model_dir: Optional[str] = field(default=None)
- tokenizer_mode: str = field(default="slow")
+ tokenizer_mode: str = field(default="fast")
load_way: str = field(default="HF")
max_total_token_num: Optional[int] = field(default=None)
mem_fraction: float = field(default=0.8)
batch_max_tokens: Optional[int] = field(default=None)
- eos_id: List[int] = field(default_factory=list)
+ eos_id: Optional[List[int]] = field(default=None)
tool_call_parser: Optional[str] = field(
default=None,
- metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]},
+ metadata={
+ "choices": [
+ "qwen25",
+ "llama3",
+ "mistral",
+ "deepseekv3",
+ "qwen",
+ "deepseekv31",
+ "deepseekv32",
+ "glm47",
+ "kimi_k2",
+ "qwen3_coder",
+ ]
+ },
)
reasoning_parser: Optional[str] = field(
default=None,
@@ -53,11 +72,12 @@ class StartArgs:
"step3",
"nano_v3",
"interns1",
+ "gemma4",
]
},
)
chat_template: Optional[str] = field(default=None)
- running_max_req_size: int = field(default=512)
+ running_max_req_size: int = field(default=256)
tp: int = field(default=1)
dp: int = field(default=1)
nnodes: int = field(default=1)
@@ -66,12 +86,13 @@ class StartArgs:
max_req_total_len: Optional[int] = field(default=None)
nccl_host: str = field(default="127.0.0.1")
nccl_port: int = field(default=None)
+ lightllm_instance_id: int = field(default=0)
use_config_server_to_init_nccl: bool = field(default=False)
trust_remote_code: bool = field(default=False)
detail_log: bool = field(default=False)
disable_log_stats: bool = field(default=False)
log_stats_interval: int = field(default=10)
- router_token_ratio: float = field(default=0.0)
+ router_token_ratio: float = field(default=None)
router_max_wait_tokens: int = field(default=1)
disable_aggressive_schedule: bool = field(default=False)
enable_prefill_decode_mixed: bool = field(default=False)
@@ -80,7 +101,7 @@ class StartArgs:
disable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
- output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
+ output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
first_token_constraint_mode: bool = field(default=False)
enable_multimodal: bool = field(default=False)
disable_vision: Optional[bool] = field(default=None)
@@ -109,12 +130,12 @@ class StartArgs:
)
metric_gateway: Optional[str] = field(default=None)
job_name: str = field(default="lightllm")
- grouping_key: List[str] = field(default_factory=list)
+ grouping_key: List[str] = field(default_factory=lambda: [])
push_interval: int = field(default=10)
visual_node_id: int = field(default=None)
visual_infer_batch_size: int = field(default=None)
visual_send_batch_size: int = field(default=1)
- visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
+ visual_gpu_ids: List[int] = field(default=None)
visual_tp: int = field(default=1)
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default=None)
@@ -132,19 +153,19 @@ class StartArgs:
graph_split_batch_size: int = field(default=32)
graph_grow_step_size: int = field(default=16)
graph_max_len_in_batch: int = field(default=0)
- quant_type: Optional[str] = field(default=None)
+ quant_type: Optional[str] = field(default="none")
quant_cfg: Optional[str] = field(default=None)
- expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]})
- vit_quant_type: Optional[str] = field(default=None)
+ vit_quant_type: Optional[str] = field(default="none")
vit_quant_cfg: Optional[str] = field(default=None)
+ expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]})
llm_prefill_att_backend: List[str] = field(
- default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
+ default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
)
llm_decode_att_backend: List[str] = field(
- default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
+ default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
)
vit_att_backend: List[str] = field(
- default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
+ default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
)
llm_kv_type: str = field(
default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]}
@@ -165,8 +186,6 @@ class StartArgs:
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
- "qwen3next_vanilla",
- "qwen3next_eagle",
None,
]
},
@@ -179,7 +198,7 @@ class StartArgs:
pd_node_id: int = field(default=-1)
enable_cpu_cache: bool = field(default=False)
cpu_cache_storage_size: float = field(default=2)
- cpu_cache_token_page_size: int = field(default=64)
+ cpu_cache_token_page_size: int = field(default=256)
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
disk_cache_dir: Optional[str] = field(default=None)
@@ -195,6 +214,27 @@ class StartArgs:
metric_port: int = field(default=None)
multinode_httpmanager_port: int = field(default=12345)
multi_level_kv_cache_port: int = field(default=None)
+ # multi_modal
+ enable_multimodal_audio: bool = field(default=False)
+
+ disable_shm_warning: bool = field(default=False)
+ dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]})
+ enable_custom_allgather: bool = field(default=False)
+ enable_fused_shared_experts: bool = field(default=False)
+ enable_mps: bool = field(default=False)
+ multinode_router_gloo_port: int = field(default=20001)
+ schedule_time_interval: float = field(default=0.03)
+ use_dynamic_prompt_cache: bool = field(default=False)
+ disable_custom_allreduce: bool = field(default=False)
+ enable_torch_memory_saver: bool = field(default=False)
+ enable_weight_cpu_backup: bool = field(default=False)
+ hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]})
+ enable_torch_fallback: bool = field(default=False)
+ enable_triton_fallback: bool = field(default=False)
+
+ enable_return_routed_experts: bool = field(default=False)
+
+ weight_version: str = "default"
# hybrid attention model (Qwen3Next)
linear_att_hash_page_size: int = field(default=512)
diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py
index 9aa3a8effc..c77379986c 100644
--- a/lightllm/server/detokenization/decode_req.py
+++ b/lightllm/server/detokenization/decode_req.py
@@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool:
return False
def need_detoken(self):
- if (
- (not self.req.is_aborted)
- and (not self.req.stop_str_matched)
- and len(self.output_ids) < self.req.candetoken_out_len
- ):
+ if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len:
return True
return False
@@ -83,8 +79,6 @@ def get_decode_tokens(self):
return prefix_tokens, read_tokens
def can_set_release_mark(self):
- if self.req.is_aborted:
- return True
if self.req.stop_str_matched:
return True
if (
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 0f1b873111..2b1c0a0afa 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -3,13 +3,14 @@
import zmq.asyncio
import asyncio
import uvloop
-import rpyc
import socket
+import rpyc
import time
import copy
import hashlib
import datetime
import pickle
+import base64
from frozendict import frozendict
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -29,8 +30,22 @@
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.basemodel.routing_manager import get_routing_config_from_model_dir, routing_dtype_id_to_np
from lightllm.utils.log_utils import init_logger
from lightllm.server.metrics.manager import MetricClient
+from lightllm.server.io_struct import (
+ AbortReq,
+ FlushCacheReq,
+ ReleaseMemoryReq,
+ ResumeMemoryReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromTensorReq,
+ UpdateWeightsFromIPCReq,
+ GeneralModelToHttpRpcRsp,
+)
+from .rl_controller import HttpRlController
from lightllm.utils.statics_utils import MovingAverage
from lightllm.utils.config_utils import get_vocab_size
from lightllm.utils.envs_utils import get_unique_server_name
@@ -56,7 +71,6 @@ def __init__(
self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0))
self._run_reqs_count_lock = AsyncLock(self._shm_lock_pool.get_lock_context(1))
self.node_rank = args.node_rank
- self.disable_abort = args.nnodes > 1 and args.dp == 1 # mulitnode dp=1 mode, disable abort
self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
self.is_multinode_tp_master = args.dp == 1 and args.nnodes > 1 and args.node_rank == 0
self.is_multinode_tp_slave = args.dp == 1 and args.nnodes > 1 and args.node_rank > 0
@@ -75,7 +89,7 @@ def __init__(
self.multinode_req_manager = context.socket(zmq.PULL)
self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}")
logger.info(
- f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}"
+ f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}"
)
self.enable_multimodal = args.enable_multimodal
@@ -122,6 +136,13 @@ def __init__(
# Timemark of the latest successful inference, used by passive /health checks.
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
self.latest_success_infer_time_mark.set_value(int(time.time()))
+ self._routing_config = (
+ get_routing_config_from_model_dir(args.model_dir)
+ if args.enable_return_routed_experts and args.node_rank == 0
+ else None
+ )
+
+ self.rl_controller = HttpRlController(self)
self.run_reqs_count_mark = SharedInt(f"{get_unique_server_name()}_run_reqs_count_mark")
self.run_reqs_count_mark.set_value(0)
@@ -335,6 +356,11 @@ async def generate(
self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() + 1)
try:
+ if await self.rl_controller.wait_until_generation_allowed(group_request_id):
+ for output in self._build_aborted_generation_outputs(group_request_id, sampling_params):
+ yield output
+ return
+
original_multimodal_params = None
if self.is_multinode_tp_master:
original_multimodal_params = copy.deepcopy(multimodal_params)
@@ -349,6 +375,7 @@ async def generate(
# 记录请求到达的相关信息
await self._log_req_header(request_headers, group_request_id)
+
# encode
prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
self._log_stage_timing(
@@ -486,6 +513,15 @@ async def generate(
self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() - 1)
return
+ def _build_aborted_generation_outputs(self, group_request_id: int, sampling_params: SamplingParams):
+ # Paused requests never enter router queues, so synthesize an aborted empty
+ # result that both stream and non-stream APIs can consume normally.
+ finish_status = FinishStatus()
+ finish_status.set_status(FinishStatus.FINISHED_ABORTED)
+ metadata = {"prompt_tokens": 0, "count_output_tokens": 0}
+ for i in range(sampling_params.n):
+ yield group_request_id + i, "", metadata, finish_status
+
def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
image_tokens = 0
audio_tokens = 0
@@ -500,6 +536,43 @@ def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple
return image_tokens, audio_tokens
+ def _read_routed_experts_metadata(self, req: Req):
+ if self._routing_config is None:
+ return None
+ num_moe_layers, topk, dtype_id = self._routing_config
+ num_tokens = req.input_len + req.shm_cur_output_len
+ if num_tokens <= 0:
+ return None
+
+ try:
+ routing_data = req.link_routing_data_shm_array(
+ num_moe_layers, num_tokens, topk, np_dtype=routing_dtype_id_to_np(dtype_id)
+ )
+ if routing_data is None:
+ return None
+ return {
+ "shape": list(routing_data.shape),
+ "dtype": str(routing_data.dtype),
+ "data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
+ }
+ except Exception as e:
+ logger.warning(f"Failed to read routing data for req {req.request_id}: {e}")
+ return None
+
+ async def _wait_and_read_routed_experts_metadata(self, req: Req, timeout: float = 60.0):
+ if not await self.rl_controller.wait_until_can_released_mark(req, timeout=timeout):
+ return None
+
+ start_time = time.time()
+ while True:
+ routing_meta = self._read_routed_experts_metadata(req)
+ if routing_meta is not None:
+ return routing_meta
+ if time.time() - start_time > timeout:
+ logger.warning(f"wait routing data shm timeout, req_id={req.request_id}, timeout={timeout}s")
+ return None
+ await asyncio.sleep(0.005)
+
async def _log_req_header(self, request_headers, group_request_id: int):
x_request_id = request_headers.get("X-Request-Id", "")
@@ -562,6 +635,18 @@ async def _encode(
else:
raise ValueError("prompt List[int] format contain id > vocab_size")
else:
+ if self.enable_multimodal and self.pd_mode.is_P_or_NORMAL():
+ assert (
+ len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
+ ), "too many multimodal items!"
+ if multimodal_params.audios:
+ assert not self.args.disable_audio, "audio multimodal not enabled"
+ await self._alloc_multimodal_resources(multimodal_params, sampling_params)
+ return self.tokenizer.encode(
+ prompt,
+ multimodal_params,
+ add_special_tokens=sampling_params.add_special_tokens,
+ )
return prompt
else:
raise ValueError(f"prompt format error, get type{type(prompt)}")
@@ -688,7 +773,7 @@ async def _wait_to_token_package(
if req_status.aborted:
raise Exception(f"req_id {group_request_id} aborted notifyed by other module")
- if not self.disable_abort and request is not None and await request.is_disconnected():
+ if request is not None and await request.is_disconnected():
await self.abort(group_request_id)
raise ClientDisconnected(
group_request_id=group_request_id, reason="_wait_to_token_package check network disconnected"
@@ -800,6 +885,9 @@ async def abort(self, group_req_id: int) -> bool:
logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
return True
+ async def abort_request(self, request: AbortReq) -> Tuple[bool, str]:
+ return await self.rl_controller.abort_request(request)
+
def _get_router_profiler_client(self):
router_profiler_client = getattr(self, "router_profiler_client", None)
if router_profiler_client is None or getattr(router_profiler_client, "closed", False):
@@ -839,6 +927,10 @@ async def recycle_resource_loop(self):
self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None)
_is_aborted = False
for req in req_status.group_req_objs.shm_req_objs:
+ try:
+ req.close_routing_data_shm_array()
+ except Exception as e:
+ logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}")
_is_aborted = _is_aborted or req.is_aborted
logger.debug(f"httpserver release req_id {req.request_id}, index {req.index_in_shm_mem}")
await self.shm_req_manager.async_put_back_req_obj(req)
@@ -931,6 +1023,11 @@ async def handle_loop(self):
else:
finish_status = FinishStatus(req.finish_status.status)
+ if self._routing_config is not None:
+ routing_meta = await self._wait_and_read_routed_experts_metadata(req)
+ if routing_meta is not None:
+ metadata["routed_experts"] = routing_meta
+
token_list.append((req_id, text, metadata, finish_status))
else:
break
@@ -945,6 +1042,36 @@ async def handle_loop(self):
self.recycle_event.set()
return
+ async def pause_generation(self):
+ return await self.rl_controller.pause_generation()
+
+ async def continue_generation(self):
+ return await self.rl_controller.continue_generation()
+
+ async def flush_cache(self, request: FlushCacheReq):
+ return await self.rl_controller.flush_cache(request)
+
+ async def release_memory_occupation(self, request: ReleaseMemoryReq):
+ return await self.rl_controller.release_memory_occupation(request)
+
+ async def resume_memory_occupation(self, request: ResumeMemoryReq):
+ return await self.rl_controller.resume_memory_occupation(request)
+
+ async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ return await self.rl_controller.init_weights_update_group(request)
+
+ async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ return await self.rl_controller.destroy_weights_update_group(request)
+
+ async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+ return await self.rl_controller.update_weights_from_distributed(request)
+
+ async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> GeneralModelToHttpRpcRsp:
+ return await self.rl_controller.update_weights_from_tensor(request)
+
+ async def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq) -> GeneralModelToHttpRpcRsp:
+ return await self.rl_controller.update_weights_from_ipc(request)
+
class ReqStatus:
def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None:
diff --git a/lightllm/server/httpserver/rl_controller.py b/lightllm/server/httpserver/rl_controller.py
new file mode 100644
index 0000000000..e40ccc916c
--- /dev/null
+++ b/lightllm/server/httpserver/rl_controller.py
@@ -0,0 +1,282 @@
+import asyncio
+import socket
+import time
+from contextlib import asynccontextmanager
+from typing import Optional, Tuple
+
+import rpyc
+from rpyc.utils.classic import obtain
+
+from lightllm.server.io_struct import (
+ AbortReq,
+ FlushCacheReq,
+ GeneralHttpToModelRpcReq,
+ GeneralModelToHttpRpcRsp,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ ReleaseMemoryReq,
+ ResumeMemoryReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromIPCReq,
+ UpdateWeightsFromTensorReq,
+)
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class _GenerationPauseGate:
+ """Generation pause gate.
+
+ Pending requests are tracked only while they are waiting at the pause gate.
+ Once generation is allowed, the request leaves this lightweight path and
+ continues as a normal request managed by req_id_to_out_inf.
+ """
+
+ _RUNNING = 0
+ _PAUSED = 1
+
+ def __init__(self) -> None:
+ self._state = self._RUNNING
+ self._pending_request_abort_events = {}
+ self._cond = asyncio.Condition()
+
+ @asynccontextmanager
+ async def pause_and_abort_context(self):
+ async with self._cond:
+ if self._state != self._RUNNING:
+ do_abort = False
+ else:
+ self._state = self._PAUSED
+ do_abort = True
+ yield do_abort
+
+ async def register_pending_request(self, request_id: int):
+ async with self._cond:
+ self._pending_request_abort_events[request_id] = asyncio.Event()
+
+ async def unregister_pending_request(self, request_id: int):
+ async with self._cond:
+ self._pending_request_abort_events.pop(request_id, None)
+
+ async def wait_until_running_or_aborted(self, request_id: int) -> bool:
+ """Returns True when the pending request should finish as aborted."""
+ async with self._cond:
+ abort_event = self._pending_request_abort_events.get(request_id)
+ if abort_event is None:
+ return False
+ if abort_event.is_set():
+ return True
+ if self._state == self._RUNNING:
+ return False
+
+ resume_task = asyncio.create_task(self._wait_until_running())
+ abort_task = asyncio.create_task(abort_event.wait())
+ done, pending = await asyncio.wait({resume_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)
+ for task in pending:
+ task.cancel()
+ if pending:
+ await asyncio.gather(*pending, return_exceptions=True)
+ return abort_task in done and abort_event.is_set()
+
+ async def abort_pending_request(self, request_id: int) -> bool:
+ async with self._cond:
+ abort_event = self._pending_request_abort_events.get(request_id)
+ if abort_event is None:
+ return False
+ abort_event.set()
+ return True
+
+ async def abort_all_pending_requests(self):
+ async with self._cond:
+ for abort_event in self._pending_request_abort_events.values():
+ abort_event.set()
+
+ async def is_pending_request(self, request_id: int) -> bool:
+ async with self._cond:
+ return request_id in self._pending_request_abort_events
+
+ async def get_pending_request_count(self) -> int:
+ async with self._cond:
+ return len(self._pending_request_abort_events)
+
+ async def _wait_until_running(self) -> None:
+ async with self._cond:
+ while self._state != self._RUNNING:
+ await self._cond.wait()
+
+ async def resume(self) -> None:
+ async with self._cond:
+ self._state = self._RUNNING
+ self._cond.notify_all()
+
+
+class HttpRlController:
+ def __init__(self, manager) -> None:
+ self.manager = manager
+ self.args = manager.args
+ self._generation_gate = _GenerationPauseGate()
+
+ async def wait_until_generation_allowed(self, request_id: int) -> bool:
+ await self._generation_gate.register_pending_request(request_id)
+ try:
+ return await self._generation_gate.wait_until_running_or_aborted(request_id)
+ finally:
+ await self._generation_gate.unregister_pending_request(request_id)
+
+ async def wait_until_can_released_mark(self, req, timeout: float = 60.0) -> bool:
+ start_time = time.time()
+ while not req.can_released_mark:
+ if time.time() - start_time > timeout:
+ logger.warning(f"wait req can_released_mark timeout, req_id={req.request_id}, timeout={timeout}s")
+ return False
+ await asyncio.sleep(0.005)
+ return True
+
+ async def _wait_for_abort_released(
+ self, request_id: Optional[int], abort_all: bool, timeout: float = 60.0
+ ) -> Tuple[bool, str]:
+ start_time = time.time()
+ empty_since = None
+ while True:
+ if abort_all:
+ pending_request_count = await self._generation_gate.get_pending_request_count()
+ if len(self.manager.req_id_to_out_inf) == 0 and pending_request_count == 0:
+ empty_since = empty_since or time.time()
+ if time.time() - empty_since >= 1.0:
+ return True, ""
+ else:
+ empty_since = None
+ await self._generation_gate.abort_all_pending_requests()
+ for group_req_id, req_status in list(self.manager.req_id_to_out_inf.items()):
+ if req_status is not None and any(
+ not req.is_aborted for req in req_status.group_req_objs.shm_req_objs
+ ):
+ await self.manager.abort(group_req_id)
+ else:
+ req_status = self.manager.req_id_to_out_inf.get(request_id, None)
+ if req_status is not None:
+ if any(not req.is_aborted for req in req_status.group_req_objs.shm_req_objs):
+ await self.manager.abort(request_id)
+ elif not await self._generation_gate.is_pending_request(request_id):
+ return True, ""
+
+ if time.time() - start_time > timeout:
+ error_msg = (
+ f"abort request wait release timeout, request_id={request_id}, abort_all={abort_all}, "
+ f"timeout={timeout}s"
+ )
+ logger.error(error_msg)
+ return False, error_msg
+
+ await asyncio.sleep(0.02)
+ return True, ""
+
+ async def abort_request(self, request: AbortReq) -> Tuple[bool, str]:
+ request_id = request.request_id
+ if request.abort_all:
+ await self._generation_gate.abort_all_pending_requests()
+ for group_req_id in list(self.manager.req_id_to_out_inf.keys()):
+ await self.manager.abort(group_req_id)
+ return await self._wait_for_abort_released(request_id=None, abort_all=True)
+
+ if request_id is None:
+ return True, ""
+
+ await self._generation_gate.abort_pending_request(request_id)
+ await self.manager.abort(request_id)
+ return await self._wait_for_abort_released(request_id=request_id, abort_all=False)
+
+ async def pause_generation(self):
+ async with self._generation_gate.pause_and_abort_context() as do_abort:
+ if not do_abort:
+ return
+ while True:
+ success, msg = await self.abort_request(AbortReq(request_id=None, abort_all=True))
+ if success:
+ break
+ logger.warning(f"pause_generation abort_all still waiting: {msg}")
+ await asyncio.sleep(1.0)
+
+ async def continue_generation(self):
+ await self._generation_gate.resume()
+
+ def _call_router_rl_sync(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ from lightllm.utils.retry_utils import retry
+
+ conn = retry(max_attempts=20, wait_time=0.5)(rpyc.connect)(
+ "localhost",
+ self.args.rl_rpyc_port,
+ config={"allow_pickle": True, "sync_request_timeout": 600},
+ )
+ try:
+ conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ return obtain(conn.root.rl_op(req))
+ finally:
+ try:
+ conn.close()
+ except BaseException:
+ pass
+
+ async def _call_router_rl(self, func_name: str, func_args=None) -> GeneralModelToHttpRpcRsp:
+ req = GeneralHttpToModelRpcReq(func_name=func_name, func_args=func_args)
+ try:
+ return await asyncio.to_thread(self._call_router_rl_sync, req)
+ except BaseException as e:
+ logger.exception(f"rl rpyc call {func_name} failed: {e}")
+ return GeneralModelToHttpRpcRsp(
+ success=False,
+ msg=f"rl rpyc call {func_name} error: {e}",
+ func_name=func_name,
+ )
+
+ async def flush_cache(self, request: FlushCacheReq):
+ return await self._call_router_rl("flush_cache", request)
+
+ async def release_memory_occupation(self, request: ReleaseMemoryReq):
+ assert (
+ len(self.manager.req_id_to_out_inf) == 0
+ ), "there are still requests running, cannot release memory occupation"
+ return await self._call_router_rl("release_memory_occupation", request.tags)
+
+ async def resume_memory_occupation(self, request: ResumeMemoryReq):
+ return await self._call_router_rl("resume_memory_occupation", request.tags)
+
+ async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ return await self._call_router_rl("init_weights_update_group", request)
+
+ async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ return await self._call_router_rl("destroy_weights_update_group", request)
+
+ async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+ if request.abort_all_requests:
+ success, msg = await self.abort_request(AbortReq(abort_all=True))
+ if not success:
+ return GeneralModelToHttpRpcRsp(success=False, msg=msg, func_name="update_weights_from_distributed")
+ if request.flush_cache:
+ ret = await self.flush_cache(FlushCacheReq())
+ if not ret.success:
+ return ret
+ return await self._call_router_rl("update_weights_from_distributed", request)
+
+ async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> GeneralModelToHttpRpcRsp:
+ if request.abort_all_requests:
+ success, msg = await self.abort_request(AbortReq(abort_all=True))
+ if not success:
+ return GeneralModelToHttpRpcRsp(success=False, msg=msg, func_name="update_weights_from_tensor")
+ if request.flush_cache:
+ ret = await self.flush_cache(FlushCacheReq())
+ if not ret.success:
+ return ret
+ return await self._call_router_rl("update_weights_from_tensor", request)
+
+ async def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq) -> GeneralModelToHttpRpcRsp:
+ if request.abort_all_requests:
+ success, msg = await self.abort_request(AbortReq(abort_all=True))
+ if not success:
+ return GeneralModelToHttpRpcRsp(success=False, msg=msg, func_name="update_weights_from_ipc")
+ if request.flush_cache:
+ ret = await self.flush_cache(FlushCacheReq())
+ if not ret.success:
+ return ret
+ return await self._call_router_rl("update_weights_from_ipc", request)
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
new file mode 100644
index 0000000000..1ce42f25ed
--- /dev/null
+++ b/lightllm/server/io_struct.py
@@ -0,0 +1,101 @@
+from dataclasses import dataclass
+from typing import List, Optional, Any, Union
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+
+
+@dataclass
+class AbortReq:
+ # 外部调用传入,等同内部的 group_req_id
+ request_id: Optional[int] = None
+ abort_all: bool = False
+
+
+def _normalize_memory_tags(tags):
+ if tags is None:
+ return None
+ return [tag if isinstance(tag, MemoryTag) else MemoryTag(tag) for tag in tags]
+
+
+@dataclass
+class FlushCacheReq:
+ pass
+
+
+@dataclass
+class ReleaseMemoryReq:
+ tags: Optional[List[MemoryTag]] = None
+
+ def __post_init__(self):
+ self.tags = _normalize_memory_tags(self.tags)
+
+
+@dataclass
+class ResumeMemoryReq:
+ tags: Optional[List[MemoryTag]] = None
+
+ def __post_init__(self):
+ self.tags = _normalize_memory_tags(self.tags)
+
+
+@dataclass
+class GeneralHttpToModelRpcReq:
+ func_name: str
+ func_args: Optional[Any] = None
+
+
+@dataclass
+class GeneralModelToHttpRpcRsp:
+ success: bool
+ msg: Optional[str]
+ func_name: str
+ func_rsp: Optional[Any] = None
+
+
+@dataclass
+class InitWeightsUpdateGroupReq:
+ master_address: str
+ master_port: int
+ rank_offset: int
+ world_size: int
+ group_name: str = "weight_update_group"
+ backend: str = "nccl"
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReq:
+ group_name: str = "weight_update_group"
+
+
+@dataclass
+class UpdateWeightsFromDistributedReq:
+ names: List[str]
+ dtypes: List[str]
+ shapes: List[List[int]]
+ group_name: str = "weight_update_group"
+ flush_cache: bool = True
+ abort_all_requests: bool = False
+ weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromTensorReq:
+ """Update model weights from tensor input.
+
+ - Tensors are serialized for transmission
+ - Data is structured in JSON for easy transmission over HTTP
+ """
+
+ serialized_named_tensors: List[Union[str, bytes]]
+ load_format: Optional[str] = None
+ flush_cache: bool = True
+ abort_all_requests: bool = False
+ weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromIPCReq:
+ ipc_handle: Optional[Union[str, dict]] = None
+ use_shm: bool = False
+ flush_cache: bool = True
+ abort_all_requests: bool = False
+ weight_version: Optional[str] = None
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index 9541e434c8..784762717d 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -77,9 +77,11 @@ def read(self):
assert self._preload_data is not None
ans = self._preload_data
self._preload_data = None
- self._data = None
return ans
+ def free(self):
+ self._data = None
+
def to_dict(self):
ret = {}
ret["uuid"] = self.uuid
diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py
index f7c099c292..971a3644cc 100644
--- a/lightllm/server/req_id_generator.py
+++ b/lightllm/server/req_id_generator.py
@@ -34,6 +34,9 @@ def __init__(self):
logger.info("ReqIDGenerator init finished")
def _wait_all_workers_ready(self):
+ if self.args.httpserver_workers == 1:
+ return
+
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs.shm_array import ShmArray
diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
index bf07e121e6..6a8e0a3917 100644
--- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
@@ -470,6 +470,10 @@ def clear_tree_nodes(self):
self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
return
+ def flush_cache(self):
+ self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
+ return
+
def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]:
assert not node.is_big_page_node()
iter_node = node
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 21e26c5854..c103a61473 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -106,6 +106,7 @@ class RadixCache:
def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
from lightllm.common.kv_cache_mem_manager import MemoryManager
+ self.total_token_num = total_token_num
self.mem_manager: MemoryManager = mem_manager
self._key_dtype = torch.int64
self._value_dtype = torch.int64
@@ -419,6 +420,10 @@ def clear_tree_nodes(self):
self.refed_tokens_num.arr[0] = 0
return
+ def flush_cache(self):
+ self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
+ return
+
def dec_node_ref_counter(self, node: TreeNode):
if node is None:
return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index dfb8866601..5401189056 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -12,7 +12,7 @@
import torch.multiprocessing as mp
import torch.distributed as dist
import multiprocessing
-from typing import Dict, List, Optional
+from typing import List
from .batch import Batch, Req
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
@@ -29,6 +29,10 @@
from lightllm.utils.profiler import ProfilerCmd
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
+from lightllm.server.io_struct import (
+ GeneralHttpToModelRpcReq,
+ GeneralModelToHttpRpcRsp,
+)
from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
@@ -36,6 +40,7 @@
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from .stats import RouterStatics
from .profiler_service import RouterProfilerCmdQueue, start_router_profiler_server
+from .rl_rpyc import RouterRlOpQueue, start_router_rl_rpyc_server
logger = init_logger(__name__)
@@ -104,6 +109,9 @@ def __init__(self, args: StartArgs):
else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
)
self.router_statics = RouterStatics(self.args)
+
+ # RL 管理 rpyc 队列:rpyc 线程把请求放进来,asyncio 主循环每个 step 取出处理
+ self.rl_op_queue = RouterRlOpQueue()
self.profiler_cmd_queue = RouterProfilerCmdQueue()
return
@@ -349,11 +357,42 @@ def _filter_reqs_from_running_batch(self):
self.running_batch = None
return
+ def _sync_abort_req_ids_from_master(self, reqs: List[Req]):
+ local_aborted_req_ids = [req.request_id for req in reqs if req.is_aborted]
+ if not self.is_multinode_tp:
+ return set(local_aborted_req_ids)
+
+ # 多节点 TP 下 abort 只以 rank 0 httpserver 看到的状态为准。
+ if not self.is_multinode_tp_master:
+ local_aborted_req_ids = []
+
+ aborted_req_num = torch.tensor([len(local_aborted_req_ids)], dtype=torch.int64, device="cpu")
+ dist.broadcast(aborted_req_num, src=0, group=self.mulitnode_group)
+ aborted_req_num = int(aborted_req_num.item())
+ if aborted_req_num == 0:
+ return set()
+
+ if self.is_multinode_tp_master:
+ aborted_req_ids = torch.tensor(local_aborted_req_ids, dtype=torch.int64, device="cpu")
+ else:
+ aborted_req_ids = torch.empty(aborted_req_num, dtype=torch.int64, device="cpu")
+ dist.broadcast(aborted_req_ids, src=0, group=self.mulitnode_group)
+ return {int(req_id) for req_id in aborted_req_ids.tolist()}
+
+ def _mark_reqs_aborted(self, reqs: List[Req], aborted_req_ids):
+ if not aborted_req_ids:
+ return
+ for req in reqs:
+ if req.request_id in aborted_req_ids:
+ req.is_aborted = True
+ return
+
def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
ans = []
- if self.running_batch is None:
- return ans
- for req in self.running_batch.reqs:
+ running_reqs = [] if self.running_batch is None else self.running_batch.reqs
+ aborted_req_ids = self._sync_abort_req_ids_from_master(running_reqs)
+ self._mark_reqs_aborted(running_reqs, aborted_req_ids)
+ for req in running_reqs:
if req.is_aborted and req._router_aborted is False:
req._router_aborted = True
ans.append(req)
@@ -453,11 +492,19 @@ def _multinode_tp_generate_new_batch(self):
dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
req_id_select_mark = [1 for _ in range(len(req_ids))]
req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+ # select_mark 仍需 MIN allreduce: slave 是否已经从 httpserver 收到 generate 请求
dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+ aborted_req_ids = self._sync_abort_req_ids_from_master(new_batch.reqs)
back_req_list = []
- for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
- if select == 0:
- req = new_batch.pop_req(req_id)
+ for req_id, select in zip(req_ids, req_id_select_mark.tolist()):
+ is_aborted = req_id in aborted_req_ids
+ if select == 1 and not is_aborted:
+ continue
+ req = new_batch.pop_req(req_id)
+ if is_aborted and select == 1:
+ req.is_aborted = True
+ self.req_queue.release_aborted_req(req)
+ else:
back_req_list.append(req)
self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
if new_batch.is_clear():
@@ -471,25 +518,30 @@ def _multinode_tp_generate_new_batch(self):
else:
req_ids = [None for _ in range(req_num)]
dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
- all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+ id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list}
req_id_select_mark = []
for req_id in req_ids:
- req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+ req_id_select_mark.append(1 if req_id in id_to_req_obj else 0)
req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
- select_req_ids = []
- for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
- if select == 1:
- select_req_ids.append(req_id)
-
+ aborted_req_ids = self._sync_abort_req_ids_from_master([])
select_reqs = []
- for req_id in select_req_ids:
- for req in self.req_queue.waiting_req_list:
- if req.request_id == req_id:
- select_reqs.append(req)
-
- for req in select_reqs:
- self.req_queue.waiting_req_list.remove(req)
+ aborted_reqs = []
+ for req_id, select in zip(req_ids, req_id_select_mark.tolist()):
+ if select == 1:
+ req = id_to_req_obj[req_id]
+ if req_id in aborted_req_ids:
+ req.is_aborted = True
+ aborted_reqs.append(req)
+ continue
+ select_reqs.append(req)
+ handled_req_ids = {req.request_id for req in select_reqs + aborted_reqs}
+ if handled_req_ids:
+ self.req_queue.waiting_req_list = [
+ req for req in self.req_queue.waiting_req_list if req.request_id not in handled_req_ids
+ ]
+ for req in aborted_reqs:
+ self.req_queue.release_aborted_req(req)
if select_reqs:
new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
else:
@@ -514,7 +566,7 @@ async def _recv_new_reqs_and_schedule(self):
if isinstance(recv_req, GroupReqIndexes):
self._add_req(recv_req)
else:
- assert False, f"Error Req Inf {recv_req}"
+ raise ValueError(f"Unknown request type: {type(recv_req)}")
# 当队列中存在较多的请求时,将一次接受的数量上调
self.recv_max_count = min(int(self.recv_max_count * 1.3), 256)
@@ -523,6 +575,8 @@ async def _recv_new_reqs_and_schedule(self):
# 当队列已经开始清空的时候,将一次接受的数量下调
self.recv_max_count = 64
+ await self._process_special_reqs()
+
if self.is_multinode_tp:
self._multinode_tp_generate_new_batch()
else:
@@ -530,6 +584,66 @@ async def _recv_new_reqs_and_schedule(self):
self._generate_new_batch()
return
+ async def _process_special_reqs(self):
+ # master: 从 RL rpyc 队列里取出 (req, future) — slave 的队列恒为空(无 rpyc service)
+ pairs = self.rl_op_queue.pop_all()
+
+ reqs: List[GeneralHttpToModelRpcReq] = [req for req, _ in pairs]
+
+ # 多机 TP:master 通过 NCCL 广播 req 到 slave router;slave 在自己的主循环里到达此处时,会从 broadcast 收到 master 的 reqs
+ if self.is_multinode_tp:
+ reqs = self.broadcast_reqs_to_other_nodes(reqs)
+
+ for i, req in enumerate(reqs):
+ assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq"
+ try:
+ ret = await self.forward_to_model(req)
+ except BaseException as e:
+ logger.exception(f"forward_to_model failed for {req.func_name}: {e}")
+ ret = GeneralModelToHttpRpcRsp(
+ success=False, msg=f"forward_to_model error: {e}", func_name=req.func_name
+ )
+ # 只有 master 持有 future,slave 的 pairs 始终为空
+ if i < len(pairs):
+ _, fut = pairs[i]
+ if not fut.done():
+ fut.set_result(ret)
+
+ def broadcast_reqs_to_other_nodes(self, reqs: List[GeneralHttpToModelRpcReq]):
+ req_num = len(reqs)
+ if self.node_rank == 0:
+ req_nums = [len(reqs)]
+ dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+ req_num = req_nums[0]
+ if req_num > 0:
+ dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+ else:
+ req_nums = [None]
+ dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+ req_num = req_nums[0]
+ if req_num > 0:
+ reqs = [None for _ in range(req_num)]
+ dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+ return reqs
+
+ async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ forward_to_model_tasks = []
+ for model_rpc_client in self.model_rpc_clients:
+ forward_to_model_tasks.append(model_rpc_client.forward_to_model(req))
+ all_ret = await asyncio.gather(*forward_to_model_tasks)
+ ret: GeneralModelToHttpRpcRsp = next((res for res in all_ret if not res.success), all_ret[0])
+ ret.success = all(res.success for res in all_ret)
+ if self.is_multinode_tp:
+ output_list = [None for _ in range(self.nnodes)] if self.node_rank == 0 else None
+ dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group)
+ if self.node_rank == 0:
+ for res in output_list:
+ res: GeneralModelToHttpRpcRsp
+ if not res.success:
+ ret = res
+ break
+ return ret
+
def clean_up(self):
return
@@ -557,6 +671,10 @@ def handle_exception(loop, context):
args,
router.profiler_cmd_queue,
)
+ router.rl_rpyc_server, router.rl_rpyc_thread = start_router_rl_rpyc_server(
+ args,
+ router.rl_op_queue,
+ )
except:
import traceback
import sys
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 5c2d0d45fb..2f79c441bb 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -23,6 +23,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.pd_io_struct import PDDecodeNodeInfo
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel import routing_manager as _routing_mgr
logger = init_logger(__name__)
@@ -121,6 +122,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
return req_objs
+ def _extract_routing_data(self, req: "InferReq"):
+ if not (req.shm_req.finish_status.is_finished() or req.shm_req.stop_str_matched):
+ return
+ mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
+ mgr = _routing_mgr.g_routing_capture_manager
+ routing_data = mgr.extract_routing_data(mem_indexes)
+ req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype)
+ req.shm_req.shm_routing_data.arr[:] = routing_data
+ req.shm_req.shm_routing_data.detach_shm()
+
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
if self.radix_cache is None:
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
@@ -274,6 +285,8 @@ def _filter(self, finished_request_ids: List[int]):
req: InferReq = self.requests_mapping.pop(request_id)
if self.args.diverse_mode:
req.clear_master_slave_state()
+ if _routing_mgr.g_routing_capture_manager is not None:
+ g_infer_context._extract_routing_data(req)
self.free_a_req_mem(free_token_index, req)
free_req_index.append(req.req_idx)
@@ -849,6 +862,8 @@ def update_finish_status(self, eos_ids, output_len: int):
self.finish_status.set_status(FinishStatus.FINISHED_STOP)
elif output_len >= self.sampling_param.shm_param.max_new_tokens:
self.finish_status.set_status(FinishStatus.FINISHED_LENGTH)
+ elif self.infer_aborted:
+ self.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
return
def _stop_sequences_matched(self, output_len: int):
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index a65dfb1bbb..25b5b7e1a4 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -4,7 +4,7 @@
import time
import threading
import torch.distributed as dist
-from typing import List, Tuple, Callable, Optional
+from typing import List, Tuple, Callable, Optional, Union
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.log_utils import init_logger
@@ -33,8 +33,8 @@
enable_radix_tree_timer_merge,
get_radix_tree_merge_update_delta,
)
+from lightllm.distributed import dist_group_manager
from lightllm.distributed.communication_op import (
- dist_group_manager,
all_gather_into_tensor,
all_reduce,
broadcast,
@@ -597,7 +597,7 @@ def _get_classed_reqs(
paused_reqs.append(req_obj)
continue
- if req_obj.infer_aborted or req_obj.finish_status.is_finished():
+ if req_obj.finish_status.is_finished():
if support_overlap:
# 延迟处理
req_obj.filter_mark = True
diff --git a/lightllm/server/router/model_infer/mode_backend/rl_backend_ops.py b/lightllm/server/router/model_infer/mode_backend/rl_backend_ops.py
new file mode 100644
index 0000000000..aab832050f
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/rl_backend_ops.py
@@ -0,0 +1,343 @@
+import gc
+from typing import List, Optional
+
+import torch
+
+from lightllm.utils.dist_utils import init_custom_process_group
+from lightllm.utils.rl.serialization import LocalSerializedTensor, MultiprocessingSerializer
+from lightllm.utils.rl.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata
+from lightllm.utils.rl.torch_cuda_ipc import cuda_rebuild_device_fallback, monkey_patch_torch_reductions
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+from lightllm.server.io_struct import (
+ FlushCacheReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromIPCReq,
+ UpdateWeightsFromTensorReq,
+)
+
+
+class RlBackendOps:
+ MEMORY_TAG_ORDER = (MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH)
+
+ SUPPORTED = frozenset(
+ {
+ "flush_cache",
+ "release_memory_occupation",
+ "resume_memory_occupation",
+ "init_weights_update_group",
+ "destroy_weights_update_group",
+ "update_weights_from_distributed",
+ "update_weights_from_tensor",
+ "update_weights_from_ipc",
+ }
+ )
+
+ def __init__(self, backend) -> None:
+ self.backend = backend
+ self._model_update_group = {}
+ self._skip_tensor_updates_reason = None
+ self.logger = backend.logger
+
+ @classmethod
+ def supports(cls, func_name: str) -> bool:
+ return func_name in cls.SUPPORTED
+
+ def dispatch(self, func_name: str, func_args):
+ if not self.supports(func_name):
+ raise ValueError(f"RlBackendOps does not support function {func_name}")
+ return getattr(self, func_name)(func_args)
+
+ def flush_cache(self, request: FlushCacheReq):
+ if self.backend.radix_cache is not None:
+ self.backend.radix_cache.flush_cache()
+ return True, "Succeeded to flush cache."
+
+ def _iter_memory_tags(self, tags: Optional[List[MemoryTag]]):
+ return self.MEMORY_TAG_ORDER if tags is None else tags
+
+ def _clear_cuda_cache(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def _pause_memory_tags(self, tags: Optional[List[MemoryTag]]):
+ torch.cuda.synchronize()
+ for tag in self._iter_memory_tags(tags):
+ self.backend.model.torch_memory_saver.pause(tag=tag)
+ self._clear_cuda_cache()
+
+ def _resume_memory_tags(self, tags: Optional[List[MemoryTag]]):
+ self._clear_cuda_cache()
+ for tag in self._iter_memory_tags(tags):
+ self.backend.model.torch_memory_saver.resume(tag=tag)
+
+ def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+ try:
+ self._pause_memory_tags(tags)
+ self.flush_cache(request=None)
+ return True, "Succeeded to release memory occupation."
+ except Exception as e:
+ self.logger.error(f"release memory occupation failed: {str(e)}")
+ return False, f"release memory occupation failed: {str(e)}"
+
+ def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+ try:
+ self._resume_memory_tags(tags)
+ return True, "Succeeded to resume memory occupation."
+ except Exception as e:
+ self.logger.error(f"resume memory occupation failed: {str(e)}")
+ return False, f"resume memory occupation failed: {str(e)}"
+
+ def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ assert torch.distributed.is_initialized(), "Default torch process group must be initialized"
+
+ assert request.group_name != "", "Group name cannot be empty"
+ rank_offset = request.rank_offset
+ rank = rank_offset + self.backend.rank_in_dp
+ world_size = request.world_size
+ group_name = request.group_name
+ self.logger.info(
+ f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, "
+ f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, "
+ f" backend={request.backend}"
+ )
+
+ try:
+ if group_name in self._model_update_group:
+ raise ValueError(f"Process group with name {group_name} already exists.")
+
+ self._model_update_group[group_name] = init_custom_process_group(
+ backend=request.backend,
+ init_method=f"tcp://{request.master_address}:{request.master_port}",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+ return True, "Succeeded to initialize custom process group."
+
+ except Exception as e:
+ message = f"Failed to initialize custom process group: {e}."
+ self.logger.error(message)
+ return False, message
+
+ def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ try:
+ if request.group_name in self._model_update_group:
+ pg = self._model_update_group.pop(request.group_name)
+ torch.distributed.destroy_process_group(pg)
+ return True, "Succeeded to destroy custom process group."
+ else:
+ return False, "The group to be destroyed does not exist."
+ except Exception as e:
+ message = f"Failed to destroy custom process group: {e}."
+ self.logger.error(message)
+ return False, message
+
+ def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+ """
+ Update model weights online through the custom weight update process group.
+ """
+
+ assert request.group_name in self._model_update_group, (
+ f"Group {request.group_name} not in {list(self._model_update_group.keys())}. "
+ "Please call `init_weights_update_group` first."
+ )
+
+ try:
+ weights = {}
+ handles = []
+ for name, dtype, shape in zip(request.names, request.dtypes, request.shapes):
+ target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+ weight = torch.empty(shape, dtype=target_dtype, device="cuda")
+ handles.append(
+ torch.distributed.broadcast(
+ weight,
+ src=0,
+ group=self._model_update_group[request.group_name],
+ async_op=True,
+ )
+ )
+ weights[name] = weight
+ for handle in handles:
+ handle.wait()
+
+ self.backend.model.load_weights(weights)
+ return True, "Succeeded to update parameter online from distributed."
+
+ except Exception as e:
+ error_msg = (
+ f"Failed to update parameter online: {e}. "
+ f"The full weights of the ModelRunner are partially updated. "
+ f"Please discard the whole weights."
+ )
+ self.logger.error(error_msg)
+ return False, error_msg
+
+ def _update_weights_from_flattened_bucket(
+ self,
+ flattened_tensor_bucket_dict,
+ ):
+ flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+ metadata = flattened_tensor_bucket_dict["metadata"]
+
+ converted_metadata = []
+ for meta in metadata:
+ if isinstance(meta, dict):
+ converted_meta = FlattenedTensorMetadata(
+ name=meta["name"],
+ shape=meta["shape"],
+ dtype=meta["dtype"],
+ start_idx=meta["start_idx"],
+ end_idx=meta["end_idx"],
+ numel=meta["numel"],
+ )
+ else:
+ converted_meta = FlattenedTensorMetadata(
+ name=meta.name,
+ shape=meta.shape,
+ dtype=meta.dtype,
+ start_idx=meta.start_idx,
+ end_idx=meta.end_idx,
+ numel=meta.numel,
+ )
+ converted_metadata.append(converted_meta)
+
+ bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata)
+ reconstructed_tensors = bucket.reconstruct_tensors()
+
+ named_tensors = {name: tensor for name, tensor in reconstructed_tensors}
+ loaded, skipped = self._load_compatible_named_tensors(named_tensors)
+
+ return (
+ True,
+ "Succeeded to update parameter online from flattened bucket tensor. "
+ f"loaded={loaded}, skipped={skipped}.",
+ )
+
+ @staticmethod
+ def _iter_named_tensors(named_tensors):
+ if isinstance(named_tensors, dict):
+ return named_tensors.items()
+ return named_tensors
+
+ def _get_tensor_update_skip_reason(self, items):
+ if self._skip_tensor_updates_reason is not None:
+ return self._skip_tensor_updates_reason
+
+ target_config = getattr(self.backend.model, "config", {}) or {}
+ target_is_moe = bool(target_config.get("num_experts") or target_config.get("moe_intermediate_size"))
+ source_has_moe_experts = any(".mlp.experts." in name for name, _ in items)
+ if source_has_moe_experts and not target_is_moe:
+ self._skip_tensor_updates_reason = "received MoE expert weights for a non-MoE backend"
+ self.logger.warning("skip tensor weight updates: %s", self._skip_tensor_updates_reason)
+ return self._skip_tensor_updates_reason
+
+ return None
+
+ def _load_compatible_named_tensors(self, named_tensors):
+ items = list(self._iter_named_tensors(named_tensors))
+ skip_reason = self._get_tensor_update_skip_reason(items)
+ if skip_reason is not None:
+ return 0, len(items)
+
+ def _load_range(weight_items):
+ if not weight_items:
+ return 0, 0
+ weight_dict = dict(weight_items)
+ try:
+ self.backend.model.load_weights(weight_dict)
+ return len(weight_items), 0
+ except Exception as e:
+ if len(weight_items) == 1:
+ name, tensor = weight_items[0]
+ self.logger.warning(
+ "skip incompatible tensor update %s shape=%s dtype=%s: %s",
+ name,
+ tuple(tensor.shape) if hasattr(tensor, "shape") else None,
+ getattr(tensor, "dtype", None),
+ e,
+ )
+ return 0, 1
+
+ split_idx = len(weight_items) // 2
+ left_loaded, left_skipped = _load_range(weight_items[:split_idx])
+ right_loaded, right_skipped = _load_range(weight_items[split_idx:])
+ return left_loaded + right_loaded, left_skipped + right_skipped
+
+ return _load_range(items)
+
+ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq):
+ try:
+ monkey_patch_torch_reductions()
+ device_module = torch.get_device_module("cuda")
+ infered_device = device_module.current_device()
+
+ if request.load_format == "flattened_bucket":
+ with cuda_rebuild_device_fallback(infered_device):
+ serialized_named_tensors = MultiprocessingSerializer.deserialize(
+ request.serialized_named_tensors[self.backend.rank_in_dp]
+ )
+ return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors)
+
+ def _unwrap_tensor(tensor, tp_rank, device):
+ if isinstance(tensor, LocalSerializedTensor):
+ tensor = tensor.get(tp_rank)
+ clone = tensor.to(device).clone()
+ del tensor
+ return clone
+
+ with cuda_rebuild_device_fallback(infered_device):
+ named_tensors = MultiprocessingSerializer.deserialize(
+ request.serialized_named_tensors[self.backend.rank_in_dp]
+ )
+ named_tensors = {
+ name: _unwrap_tensor(tensor, tp_rank=self.backend.rank_in_dp, device=infered_device)
+ for name, tensor in self._iter_named_tensors(named_tensors)
+ }
+
+ loaded, skipped = self._load_compatible_named_tensors(named_tensors)
+
+ return True, f"Succeeded to update parameter online from tensor. loaded={loaded}, skipped={skipped}."
+
+ except Exception as e:
+ message = f"Failed to update parameter online from tensor. Reason: {e}."
+ self.logger.error(message)
+
+ return False, message
+
+ def update_weights_from_ipc(self, request: UpdateWeightsFromIPCReq):
+ try:
+ from lightllm.utils.rl.bucketed_weight_transfer import BucketedWeightReceiver, get_zmq_handle
+
+ zmq_handle = request.ipc_handle
+ if isinstance(zmq_handle, dict):
+ zmq_handle = zmq_handle.get(self.backend.rank_in_node, zmq_handle.get(str(self.backend.rank_in_node)))
+ if zmq_handle is None:
+ raise ValueError(f"Missing ipc_handle for rank_in_node={self.backend.rank_in_node}")
+ if zmq_handle in (None, "", "auto"):
+ zmq_handle = get_zmq_handle()
+ use_shm = request.use_shm
+ recv_device = torch.device("cuda", self.backend.current_device_id)
+ self.logger.debug(
+ "[LightLLM] RlBackendOps.update_weights_from_ipc: request.ipc_handle=%r, "
+ "resolved zmq_handle=%r, cuda_device_id=%s",
+ request.ipc_handle,
+ zmq_handle,
+ self.backend.current_device_id,
+ )
+
+ bucketed_weight_receiver = BucketedWeightReceiver(
+ zmq_handle=zmq_handle, device=recv_device, use_shm=use_shm
+ )
+ bucketed_weight_receiver.receive_weights(on_bucket_received=self.backend.model.load_weights)
+ return True, "Succeeded to update parameter online from ipc."
+
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ message = f"Failed to update parameter online from ipc. Reason: {e}."
+ self.logger.error(message)
+
+ return False, message
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 864a7405b7..1527b69d2b 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -28,11 +28,14 @@
PDDPForDecodeNode,
)
from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager
+from lightllm.server.router.model_infer.mode_backend.rl_backend_ops import RlBackendOps
from lightllm.server.core.objs.start_args_type import StartArgs
from lightllm.utils.log_utils import init_logger
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp
logger = init_logger(__name__)
@@ -46,6 +49,8 @@ def __init__(self, args, rank: int, rank_in_node: int, node_world_size: int, inf
self.rank = rank
self.rank_in_node = rank_in_node
+ self.backend = None
+ self.rl_backend_ops = None
logger.info(f"Initialized RPC server for rank {self.rank}.")
return
@@ -99,6 +104,7 @@ def exposed_init_model(self, kvargs):
logger.info(f"use {self.backend.__class__.__name__}")
self.backend.init_model(kvargs)
+ self.rl_backend_ops = RlBackendOps(self.backend)
# only deepseekv3 can support auto_update_redundancy_expert
if self.args.auto_update_redundancy_expert:
@@ -111,6 +117,24 @@ def exposed_init_model(self, kvargs):
def exposed_get_max_total_token_num(self):
return self.backend.get_max_total_token_num()
+ def exposed_forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ try:
+ req = obtain(req)
+ if self.rl_backend_ops is None:
+ raise ValueError("RL backend ops is not initialized")
+ if not RlBackendOps.supports(req.func_name):
+ raise ValueError(
+ f"Unsupported RL backend function {req.func_name}. "
+ f"Supported functions: {sorted(RlBackendOps.SUPPORTED)}"
+ )
+ success, ret = self.rl_backend_ops.dispatch(req.func_name, req.func_args)
+ return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret)
+ except BaseException as e:
+ logger.exception(f"forward to model backend failed: {str(e)}")
+ return GeneralModelToHttpRpcRsp(
+ success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name
+ )
+
class ModelRpcClient:
def __init__(self, conn):
@@ -134,6 +158,7 @@ async def _func(*args, **kwargs):
self._init_model = async_wrap(self.conn.root.init_model)
self._get_max_total_token_num = async_wrap(self.conn.root.get_max_total_token_num)
+ self._forward_to_model = async_wrap(self.conn.root.forward_to_model)
return
async def init_model(self, kvargs):
@@ -145,6 +170,10 @@ async def get_max_total_token_num(self):
ans = self._get_max_total_token_num()
return obtain(await ans)
+ async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ ans = self._forward_to_model(req)
+ return obtain(await ans)
+
def _init_env(
args,
@@ -197,7 +226,11 @@ async def start_model_process(
success_event,
),
)
- proc.start()
+ from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper
+
+ torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver)
+ with torch_memory_saver.configure_subprocess():
+ proc.start()
# Use asyncio.to_thread to make the blocking wait non-blocking
await asyncio.to_thread(success_event.wait, timeout=40)
diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py
index 0d1ffe6967..5a31c77de6 100644
--- a/lightllm/server/router/req_queue/base_queue.py
+++ b/lightllm/server/router/req_queue/base_queue.py
@@ -4,6 +4,9 @@
from lightllm.server.core.objs import FinishStatus
from lightllm.utils.config_utils import get_fixed_kv_len
from lightllm.server.core.objs import StartArgs
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
class BaseQueue:
@@ -32,6 +35,30 @@ def free_aborted_req_cpu_cache_pages(self, req: Req):
req.cpu_cache_match_page_indexes.clear()
self.router.cpu_cache_client.lock.release()
+ def should_release_aborted_req_in_queue(self, req: Req):
+ # 多节点 TP 的 waiting req abort 状态必须先由 rank 0 broadcast 对齐,
+ # 不能在各节点本地调度队列里提前按各自 shm 状态释放。
+ return req.is_aborted and not self.router.is_multinode_tp
+
+ def mark_aborted_req_finished(self, req: Req):
+ # 未开始推理的请求没有生成 token;这里写入一个 EOS 位置和 aborted 状态,
+ # 让 httpserver recycle loop 能正常结束请求并返回空字符串。
+ input_len = req.input_len
+ req.link_prompt_ids_shm_array()
+ req.link_logprobs_shm_array()
+ req.candetoken_out_len = 1
+ req.finish_token_index = input_len
+ req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0]
+ req.shm_logprobs.arr[input_len] = 0
+ req.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
+
+ def release_aborted_req(self, req: Req):
+ logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
+ self.free_aborted_req_cpu_cache_pages(req)
+ self.mark_aborted_req_finished(req)
+ self.router.shm_req_manager.put_back_req_obj(req)
+ return
+
def extend(self, req_group: List[Req]):
for req in req_group:
req.sample_params.suggested_dp_index = self.dp_index
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index e82cc7e181..4a209ba7be 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -2,9 +2,6 @@
import numpy as np
from ...batch import Batch, Req
from lightllm.server.router.req_queue.base_queue import BaseQueue
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
class ChunkedPrefillQueue(BaseQueue):
@@ -73,21 +70,20 @@ def generate_new_batch(self, current_batch: Batch):
self._init_cache_list(current_batch, is_busy)
can_run_list = []
abort_req_list = []
- aborted_count = 0
+ consumed_req_count = 0
waiting_queue = self.waiting_req_list
for req in waiting_queue:
- if req.is_aborted:
- # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
- # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏
- aborted_count += 1
+ if self.should_release_aborted_req_in_queue(req):
+ consumed_req_count += 1
abort_req_list.append(req)
continue
ok_insert, new_batch_first_router_need_tokens = self._can_add_new_req(
req, is_busy, new_batch_first_router_need_tokens
)
if ok_insert:
+ consumed_req_count += 1
can_run_list.append(req)
else:
break
@@ -95,11 +91,8 @@ def generate_new_batch(self, current_batch: Batch):
if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
for req in abort_req_list:
- req: Req = req
- logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
- self.free_aborted_req_cpu_cache_pages(req)
- self.router.shm_req_manager.put_back_req_obj(req)
- self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+ self.release_aborted_req(req)
+ self.waiting_req_list = self.waiting_req_list[consumed_req_count:]
return new_batch
def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py
index 5ec09f5760..11e4198fb2 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd.py
@@ -3,9 +3,6 @@
from typing import Tuple
from ...batch import Batch, Req
from lightllm.server.router.req_queue.base_queue import BaseQueue
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
class PDQueue(BaseQueue):
@@ -69,21 +66,20 @@ def generate_new_batch(self, current_batch: Batch):
can_run_list = []
abort_req_list = []
- aborted_count = 0
+ consumed_req_count = 0
waiting_queue = self.waiting_req_list
for req in waiting_queue:
- if req.is_aborted:
- # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
- # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏
- aborted_count += 1
+ if self.should_release_aborted_req_in_queue(req):
+ consumed_req_count += 1
abort_req_list.append(req)
continue
ok_insert, estimated_peak_token_num, batch_req_num = self._can_add_new_req(
req=req, estimated_peak_token_num=estimated_peak_token_num, batch_req_num=batch_req_num
)
if ok_insert:
+ consumed_req_count += 1
can_run_list.append(req)
else:
break
@@ -91,11 +87,8 @@ def generate_new_batch(self, current_batch: Batch):
if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
for req in abort_req_list:
- req: Req = req
- logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
- self.free_aborted_req_cpu_cache_pages(req)
- self.router.shm_req_manager.put_back_req_obj(req)
- self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+ self.release_aborted_req(req)
+ self.waiting_req_list = self.waiting_req_list[consumed_req_count:]
return new_batch
def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py
index 866e1b9f42..af8f875d4e 100644
--- a/lightllm/server/router/req_queue/dp_base_queue.py
+++ b/lightllm/server/router/req_queue/dp_base_queue.py
@@ -26,6 +26,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
self.reqs_waiting_for_dp_index: List[List[Req]] = []
return
+ def release_aborted_req(self, req: Req):
+ dp_index = req.sample_params.suggested_dp_index
+ assert dp_index >= 0 and dp_index < self.dp_size_in_node
+ self.inner_queues[dp_index].release_aborted_req(req)
+ return
+
def get_dp_queue(self, dp_index: int):
assert dp_index < self.dp_size_in_node, "dp index out of range"
return self.inner_queues[dp_index]
diff --git a/lightllm/server/router/rl_rpyc.py b/lightllm/server/router/rl_rpyc.py
new file mode 100644
index 0000000000..86a32df0a2
--- /dev/null
+++ b/lightllm/server/router/rl_rpyc.py
@@ -0,0 +1,66 @@
+import concurrent.futures
+import queue
+import threading
+from typing import List, Tuple
+
+import rpyc
+from rpyc.utils.classic import obtain
+
+from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class RouterRlOpQueue:
+ def __init__(self):
+ self._queue: "queue.Queue[Tuple[GeneralHttpToModelRpcReq, concurrent.futures.Future]]" = queue.Queue()
+
+ def submit(self, req: GeneralHttpToModelRpcReq, timeout: float = 300.0) -> GeneralModelToHttpRpcRsp:
+ fut: concurrent.futures.Future = concurrent.futures.Future()
+ self._queue.put((req, fut))
+ try:
+ return fut.result(timeout=timeout)
+ except concurrent.futures.TimeoutError:
+ return GeneralModelToHttpRpcRsp(
+ success=False,
+ msg=f"rl op {req.func_name} timeout after {timeout}s",
+ func_name=req.func_name,
+ )
+
+ def pop_all(self) -> List[Tuple[GeneralHttpToModelRpcReq, concurrent.futures.Future]]:
+ pairs = []
+ while True:
+ try:
+ pairs.append(self._queue.get_nowait())
+ except queue.Empty:
+ break
+ return pairs
+
+
+class RouterRlRpcService(rpyc.Service):
+ def __init__(self, rl_op_queue: RouterRlOpQueue):
+ super().__init__()
+ self.rl_op_queue = rl_op_queue
+
+ def exposed_rl_op(self, req: GeneralHttpToModelRpcReq):
+ return self.rl_op_queue.submit(obtain(req))
+
+
+def start_router_rl_rpyc_server(args, rl_op_queue: RouterRlOpQueue):
+ if args.node_rank != 0 or args.rl_rpyc_port is None:
+ return None, None
+
+ from rpyc.utils.server import ThreadedServer
+ import lightllm.utils.rpyc_fix_utils as _
+
+ server = ThreadedServer(
+ RouterRlRpcService(rl_op_queue),
+ hostname="127.0.0.1",
+ port=args.rl_rpyc_port,
+ protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
+ )
+ thread = threading.Thread(target=server.start, name="rl_rpyc_server", daemon=True)
+ thread.start()
+ logger.info(f"router rl rpyc server started on port {args.rl_rpyc_port}")
+ return server, thread
diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py
index 3e74793634..b08e2d13e2 100644
--- a/lightllm/server/visualserver/model_infer/__init__.py
+++ b/lightllm/server/visualserver/model_infer/__init__.py
@@ -11,6 +11,7 @@
from rpyc.utils.server import ThreadedServer
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name
+from lightllm.utils.process_check import start_parent_check_thread
from .model_rpc_client import VisualModelRpcClient
from .model_rpc import VisualModelRpcServer
from ..objs import rpyc_config
@@ -20,6 +21,7 @@ def _init_env(socket_path: str, success_event):
# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer")
+ start_parent_check_thread()
import lightllm.utils.rpyc_fix_utils as _
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 5b9705ed0e..87aae86e4f 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -80,12 +80,15 @@ def init_vision_distributed_env(kvargs):
device_id = kvargs["device_id"]
set_current_device_id(device_id)
torch.cuda.set_device(device_id)
+ # 不要在init_process_group时,显示的传入device_id
+ # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank
+ # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行
+ # 通信原语时例如all_reduce,会永远等不到LightLLM默认组里的回复,从而导致错误结果。
dist.init_process_group(
"nccl",
init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
rank=kvargs["tp_rank_id"],
world_size=tp_world_size,
- device_id=torch.device(f"cuda:{device_id}"),
)
# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -150,7 +153,6 @@ def init_distributed_env(kvargs):
init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
rank=kvargs["rank_id"],
world_size=kvargs["world_size"],
- device_id=torch.device(f"cuda:{device_id}"),
)
# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -316,3 +318,71 @@ def _init_nccl_env():
assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"
return
+
+
+# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675
+def init_custom_process_group(
+ backend=None,
+ init_method=None,
+ timeout=None,
+ world_size=-1,
+ rank=-1,
+ store=None,
+ group_name=None,
+ pg_options=None,
+ device_id=None,
+):
+ from torch.distributed.distributed_c10d import (
+ Backend,
+ PrefixStore,
+ _new_process_group_helper,
+ _world,
+ default_pg_timeout,
+ rendezvous,
+ )
+
+ assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
+
+ if store is not None:
+ assert world_size > 0, "world_size must be positive if using store"
+ assert rank >= 0, "rank must be non-negative if using store"
+ elif init_method is None:
+ init_method = "env://"
+
+ if backend:
+ backend = Backend(backend)
+ else:
+ backend = Backend("undefined")
+
+ if timeout is None:
+ timeout = default_pg_timeout
+
+ # backward compatible API
+ if store is None:
+ rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+
+ # Use a PrefixStore to avoid accidental overrides of keys used by
+ # different systems (e.g. RPC) in case the store is multi-tenant.
+ store = PrefixStore(group_name, store)
+
+ # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+ # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+ # We need to determine the appropriate parameter name based on PyTorch version
+ pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+ pg, _ = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ **{pg_options_param_name: pg_options},
+ timeout=timeout,
+ device_id=device_id,
+ )
+
+ _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+ return pg
diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py
index b87096d945..def49a2fa5 100644
--- a/lightllm/utils/net_utils.py
+++ b/lightllm/utils/net_utils.py
@@ -2,14 +2,28 @@
import subprocess
import ipaddress
import random
+import os
from lightllm.utils.log_utils import init_logger
logger = init_logger(__name__)
+DEFAULT_BASE_PORT = 10000
+PORTS_PER_INSTANCE = 1000
+MAX_INSTANCE_ID = 7
+
+
+def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=DEFAULT_BASE_PORT, instance_id=0):
+ if instance_id < 0 or instance_id > MAX_INSTANCE_ID:
+ raise ValueError(f"instance_id must be in range [0, {MAX_INSTANCE_ID}], got {instance_id}")
+
+ base_port = int(os.environ.get("LIGHTLLM_BASE_PORT", from_port_num))
+ # Keep independent launchers away from the same free-port window, especially for NCCL TCPStore ports.
+ range_start = base_port + instance_id * PORTS_PER_INSTANCE
+ range_end = range_start + PORTS_PER_INSTANCE
+ used_ports = used_ports or []
-def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=10000):
port_list = []
- for port in range(from_port_num, 65536):
+ for port in range(range_start, range_end):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
result = s.connect_ex(("localhost", port))
if result != 0 and port not in used_ports:
diff --git a/lightllm/utils/rl/__init__.py b/lightllm/utils/rl/__init__.py
new file mode 100644
index 0000000000..12d5ab89e5
--- /dev/null
+++ b/lightllm/utils/rl/__init__.py
@@ -0,0 +1,6 @@
+"""
+Utilities used by RL weight-update and colocated rollout integration paths.
+
+This package groups the CUDA IPC serializer, tensor bucketing helpers, and
+bucketed weight-transfer protocol used by the RL endpoints.
+"""
diff --git a/lightllm/utils/rl/bucketed_weight_transfer.py b/lightllm/utils/rl/bucketed_weight_transfer.py
new file mode 100644
index 0000000000..4849497f1a
--- /dev/null
+++ b/lightllm/utils/rl/bucketed_weight_transfer.py
@@ -0,0 +1,302 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Bucketed RL weight transfer via ZMQ plus CUDA IPC or shared memory fallback.
+
+This module builds on torch_cuda_ipc.py for CUDA IPC device handling. It owns
+the higher-level transfer protocol: the sender publishes one reusable
+communication buffer, and the receiver rebuilds that buffer on its target
+device before applying each metadata-described bucket to model weights.
+
+Copied from:
+https://github.com/verl-project/verl/blob/main/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py
+"""
+
+import gc
+from multiprocessing import shared_memory
+from typing import TypedDict
+
+import torch
+import zmq
+from torch.multiprocessing.reductions import reduce_tensor
+from lightllm.utils.rl.torch_cuda_ipc import (
+ cuda_device_to_uuid,
+ get_current_device_id,
+ get_current_device_name,
+ rebuild_cuda_ipc_tensor,
+)
+
+
+def get_zmq_handle() -> str:
+ return f"ipc:///tmp/rl-colocate-zmq-{cuda_device_to_uuid(get_current_device_id())}.sock"
+
+
+class TensorMetadata(TypedDict):
+ name: str
+ shape: torch.Size
+ dtype: torch.dtype
+ offset: int
+
+
+def create_shared_memory(size: int, name: str):
+ """Create shared memory for weight transfer. If already exists, attach to it."""
+ try:
+ shm = shared_memory.SharedMemory(name=name, create=True, size=size)
+ except FileExistsError:
+ shm = shared_memory.SharedMemory(name=name)
+ assert shm.size >= size, f"Stale shm segment '{name}': expected {size} bytes, got {shm.size}"
+ return shm
+
+
+def rebuild_shared_memory(name: str, size: int, dtype=torch.uint8):
+ """Rebuild tensor from shared memory."""
+ shm = shared_memory.SharedMemory(name=name)
+ tensor = torch.frombuffer(shm.buf[:size], dtype=dtype)
+
+ return tensor, shm
+
+
+class BucketedWeightSender:
+ """
+ Send model weights via bucketed IPC transfer over ZMQ.
+
+ Packs weight tensors into a fixed-size communication buffer and sends them
+ in buckets to the receiver. Supports CUDA IPC and shared memory fallback.
+
+ Args:
+ zmq_handle: ZMQ IPC socket path (e.g., "ipc:///tmp/rl-colocate-zmq-.sock")
+ bucket_size_mb: Communication buffer size in MB
+ use_shm: Use shared memory instead of CUDA IPC (for NPU compatibility)
+ """
+
+ def __init__(
+ self,
+ zmq_handle: str,
+ bucket_size_mb: int = 512,
+ use_shm: bool = False,
+ ):
+ self.zmq_handle = zmq_handle
+ self.bucket_size_mb = bucket_size_mb
+ self.bucket_size = int(bucket_size_mb) << 20
+ self.use_shm = use_shm
+
+ self.zmq_context = zmq.Context.instance()
+ self.socket = None
+ self.buffer = None
+ self.shm = None
+
+ async def async_send_weights(self, weights):
+ """
+ Send weights to the receiver. Accepts a sync generator or async iterator.
+
+ Args:
+ weights: Generator or async iterator yielding (name, tensor) pairs
+ """
+ from verl.workers.rollout.utils import ensure_async_iterator
+
+ try:
+ self._init_socket()
+ self._init_buffer()
+
+ # send bucket weights
+ offset = 0
+ bucket_meta: dict[str, TensorMetadata] = {}
+ # dtype = PrecisionType.to_dtype(self.config.dtype)
+ async for name, weight in ensure_async_iterator(weights):
+ # model parameters are in fp32 full precision
+ # (vermouth1992) we should not force cast weight here because some parameters
+ # (such as moe gate) have to keep fp32 precision. If a weight is bf16 in the rollout side,
+ # the rollout should automatically cast on demand. However, this would incur a higher weight
+ # transfer volume.
+ # weight = weight.to(dtype, non_blocking=True)
+
+ # fill the tensor bucket
+ if offset + weight.nbytes > self.bucket_size:
+ torch.cuda.synchronize()
+ self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False})
+ self.socket.recv()
+ bucket_meta = {}
+ offset = 0
+
+ # TODO: slice embedding layer weight into chunks
+ assert offset + weight.nbytes <= self.bucket_size, (
+ f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
+ f"Please increase rollout.update_weights_bucket_megabytes({self.bucket_size_mb} MB)."
+ )
+ bucket_meta[name] = {
+ "name": name,
+ "shape": weight.shape,
+ "dtype": weight.dtype,
+ "offset": offset,
+ }
+ self.buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
+ offset += weight.nbytes
+
+ # send the last bucket
+ torch.cuda.synchronize()
+ self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True})
+ self.socket.recv()
+ finally:
+ self._cleanup()
+
+ def _init_socket(self):
+ """Initialize ZMQ REQ socket and bind."""
+ self.socket = self.zmq_context.socket(zmq.REQ)
+ self.socket.bind(self.zmq_handle)
+
+ def _init_buffer(self):
+ """build communication buffer"""
+ buffer, shm = None, None
+ if not self.use_shm:
+ buffer = torch.empty(
+ self.bucket_size,
+ dtype=torch.uint8,
+ device=f"{get_current_device_name()}:{get_current_device_id()}",
+ )
+ handle = reduce_tensor(buffer)
+ self.socket.send_pyobj(handle)
+ else:
+ import uuid
+
+ # Create unique name for shared memory
+ shm_name = f"verl_weights_{uuid.uuid4().hex}"
+ shm = create_shared_memory(self.bucket_size, shm_name)
+ buffer = torch.frombuffer(shm.buf, dtype=torch.uint8)
+
+ comm_metadata = {"name": shm_name, "size": self.bucket_size}
+ self.socket.send_pyobj(comm_metadata)
+
+ self.socket.recv()
+ self.buffer = buffer
+ self.shm = shm
+
+ def _cleanup(self):
+ """clean up"""
+ if self.socket is not None:
+ self.socket.close()
+ self.socket = None
+ del self.buffer
+ self.buffer = None
+ if self.shm is not None:
+ self.shm.close()
+ self.shm.unlink()
+ del self.shm
+ self.shm = None
+ gc.collect()
+ torch.cuda.ipc_collect()
+ torch.cuda.empty_cache()
+
+
+class BucketedWeightReceiver:
+ """
+ Receive model weights via bucketed IPC transfer over ZMQ.
+
+ Receives weight tensors from BucketedWeightSender and passes each
+ bucket to a callback for processing (e.g., loading into the model).
+
+ Args:
+ zmq_handle: ZMQ IPC socket path (must match sender)
+ device: Target device for received tensors
+ use_shm: Use shared memory instead of CUDA IPC
+ """
+
+ def __init__(
+ self,
+ zmq_handle: str,
+ device: torch.device,
+ use_shm: bool = False,
+ ):
+ self.zmq_handle = zmq_handle
+ self.device = device
+ self.use_shm = use_shm
+
+ self.zmq_context = zmq.Context.instance()
+ self.socket = None
+ self.buffer = None
+ self.shm = None
+
+ def receive_weights(self, on_bucket_received: callable):
+ """
+ Receive weights from sender and process each bucket via callback.
+
+ Args:
+ on_bucket_received: Callback function(weight_dict: dict[str, torch.Tensor]) called per bucket.
+ """
+ try:
+ self._init_socket()
+ self._init_buffer()
+
+ # receive bucket and update weights
+ while True:
+ metadata = self.socket.recv_pyobj()
+ weights, tensor = [], None
+ for name, meta in metadata["bucket_meta"].items():
+ shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"]
+ size = dtype.itemsize * shape.numel()
+ # NOTE: we need to clone the tensor to release CUDA IPC memory
+ # but for shared memory, it's not necessary and if we do clone,
+ # it will cause extra memory copy overhead and slow down the process.
+ tensor = self.buffer[offset : offset + size].view(dtype=dtype).view(shape)
+ if not self.use_shm:
+ tensor = tensor.clone()
+ else:
+ tensor = tensor.to(self.device)
+ weights.append((name, tensor))
+ torch.cuda.synchronize()
+ self.socket.send(b"")
+ on_bucket_received(dict(weights))
+ del weights, tensor
+ if metadata["is_last"]:
+ break
+ finally:
+ self._cleanup()
+
+ def _init_socket(self):
+ """Initialize ZMQ REP socket and connect."""
+ self.socket = self.zmq_context.socket(zmq.REP)
+ self.socket.connect(self.zmq_handle)
+
+ def _init_buffer(self):
+ """Receive and rebuild communication buffer from sender."""
+ comm_metadata = self.socket.recv_pyobj()
+ buffer, shm = None, None
+ if not self.use_shm:
+ handle = comm_metadata
+ buffer = rebuild_cuda_ipc_tensor(handle, self.device.index)
+ assert buffer.dtype == torch.uint8
+ else:
+ shm_name = comm_metadata["name"]
+ shm_size = comm_metadata["size"]
+ buffer, shm = rebuild_shared_memory(shm_name, shm_size, dtype=torch.uint8)
+ self.socket.send(b"")
+ self.buffer = buffer
+ self.shm = shm
+
+ def _cleanup(self):
+ """clean up"""
+ if self.socket is not None:
+ self.socket.close()
+ self.socket = None
+ # Synchronize before releasing the buffer to ensure all async ops
+ # referencing it (e.g. clone, .to()) have completed.
+ torch.cuda.synchronize()
+ del self.buffer
+ self.buffer = None
+ if self.shm is not None:
+ self.shm.close()
+ del self.shm
+ self.shm = None
+ gc.collect()
+ torch.cuda.ipc_collect()
+ torch.cuda.empty_cache()
diff --git a/lightllm/utils/rl/serialization.py b/lightllm/utils/rl/serialization.py
new file mode 100644
index 0000000000..e5be3f0968
--- /dev/null
+++ b/lightllm/utils/rl/serialization.py
@@ -0,0 +1,140 @@
+"""
+Serialization helpers for RL weight-update requests.
+
+The tensor update endpoint passes CUDA tensors across processes by serializing
+their multiprocessing IPC handles rather than copying tensor data into the HTTP
+payload. This module wraps ForkingPickler for that handoff and uses a guarded
+unpickler because the payload may come from a trainer process.
+
+Copied from:
+https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py
+"""
+import base64
+import pickle
+import io
+from dataclasses import dataclass
+from multiprocessing.reduction import ForkingPickler
+from typing import List
+
+
+class MultiprocessingSerializer:
+ @staticmethod
+ def serialize(obj, output_str: bool = False):
+ """
+ Serialize a Python object using ForkingPickler.
+
+ Args:
+ obj: The object to serialize.
+ output_str (bool): If True, return a base64-encoded string instead of raw bytes.
+
+ Returns:
+ bytes or str: The serialized object.
+ """
+ buf = io.BytesIO()
+ ForkingPickler(buf).dump(obj)
+ buf.seek(0)
+ output = buf.read()
+
+ if output_str:
+ # Convert bytes to base64-encoded string
+ output = base64.b64encode(output).decode("utf-8")
+
+ return output
+
+ @staticmethod
+ def deserialize(data):
+ """
+ Deserialize a previously serialized object.
+
+ Args:
+ data (bytes or str): The serialized data, optionally base64-encoded.
+
+ Returns:
+ The deserialized Python object.
+ """
+ if isinstance(data, str):
+ # Decode base64 string to bytes
+ data = base64.b64decode(data, validate=True)
+
+ return SafeUnpickler(io.BytesIO(data)).load()
+
+
+class SafeUnpickler(pickle.Unpickler):
+ ALLOWED_MODULE_PREFIXES = {
+ # --- Python types ---
+ "builtins.",
+ "collections.",
+ "copyreg.",
+ "functools.",
+ "itertools.",
+ "operator.",
+ "types.",
+ "weakref.",
+ # --- PyTorch types ---
+ "torch.",
+ "torch._tensor.",
+ "torch.storage.",
+ "torch.nn.parameter.",
+ "torch.autograd.function.",
+ # --- torch distributed ---
+ "torch.distributed.",
+ "torch.distributed._shard.",
+ "torch.distributed._composable.",
+ "torch._C._distributed_c10d.",
+ "torch._C._distributed_fsdp.",
+ "torch.distributed.optim.",
+ # --- multiprocessing ---
+ "multiprocessing.resource_sharer.",
+ "multiprocessing.reduction.",
+ "pickletools.",
+ # --- PEFT / LoRA ---
+ "peft.",
+ "transformers.",
+ "huggingface_hub.",
+ # --- SGLang & Unitest ---
+ "sglang.srt.weight_sync.tensor_bucket.",
+ "sglang.srt.model_executor.model_runner.",
+ "sglang.srt.layers.",
+ "sglang.srt.utils.",
+ # --- LightLLM ---
+ "lightllm.utils.",
+ }
+
+ DENY_CLASSES = {
+ ("builtins", "eval"),
+ ("builtins", "exec"),
+ ("builtins", "compile"),
+ ("os", "system"),
+ ("subprocess", "Popen"),
+ ("subprocess", "run"),
+ ("codecs", "decode"),
+ ("types", "CodeType"),
+ ("types", "FunctionType"),
+ }
+
+ def find_class(self, module, name):
+ # Block deterministic attacks
+ if (module, name) in self.DENY_CLASSES:
+ raise RuntimeError(
+ f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164"
+ )
+ # Allowlist of safe-to-load modules.
+ if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES):
+ return super().find_class(module, name)
+
+ # Block everything else. (Potential attack surface)
+ raise RuntimeError(
+ f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164"
+ )
+
+
+@dataclass
+class LocalSerializedTensor:
+ """torch.Tensor that gets serialized by MultiprocessingSerializer
+ (which only serializes a pointer and not the data).
+ The i-th element in the list corresponds to i-th rank's GPU."""
+
+ values: List[bytes]
+
+ def get(self, rank: int):
+ return MultiprocessingSerializer.deserialize(self.values[rank])
diff --git a/lightllm/utils/rl/tensor_bucket.py b/lightllm/utils/rl/tensor_bucket.py
new file mode 100644
index 0000000000..72defe9859
--- /dev/null
+++ b/lightllm/utils/rl/tensor_bucket.py
@@ -0,0 +1,111 @@
+"""
+Flattened tensor buckets used by RL weight-update requests.
+
+Some trainer-side integrations send many named tensors as one flattened byte
+tensor plus metadata. The server reconstructs the named tensors before passing
+them to model weight loading.
+
+Copied from:
+https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/srt/weight_sync/tensor_bucket.py
+"""
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+ """Metadata for a tensor in a flattened bucket"""
+
+ name: str
+ shape: torch.Size
+ dtype: torch.dtype
+ start_idx: int
+ end_idx: int
+ numel: int
+
+
+class FlattenedTensorBucket:
+ """
+ A bucket that flattens multiple tensors into a single tensor for efficient processing
+ while preserving all metadata needed for reconstruction.
+ """
+
+ # This field is solely for users of to check whether the class supports this feature
+ supports_multi_dtypes = True
+
+ def __init__(
+ self,
+ named_tensors: List[Tuple[str, torch.Tensor]] = None,
+ flattened_tensor: torch.Tensor = None,
+ metadata: List[FlattenedTensorMetadata] = None,
+ ):
+ """
+ Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+ Args:
+ named_tensors: List of (name, tensor) tuples (for creating new bucket)
+ flattened_tensor: Pre-flattened tensor (for reconstruction)
+ metadata: Pre-computed metadata (for reconstruction)
+ """
+ if named_tensors is not None:
+ # Create bucket from named tensors
+ self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+ self.flattened_tensor: torch.Tensor = None
+
+ if not named_tensors:
+ raise ValueError("Cannot create empty tensor bucket")
+
+ # Collect metadata and flatten tensors
+ current_idx = 0
+ flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+ for i, (name, tensor) in enumerate(named_tensors):
+ flattened = tensor.flatten().view(torch.uint8)
+ flattened_tensors[i] = flattened
+
+ # Store metadata
+
+ numel = flattened.numel()
+ metadata_obj = FlattenedTensorMetadata(
+ name=name,
+ shape=tensor.shape,
+ dtype=tensor.dtype,
+ start_idx=current_idx,
+ end_idx=current_idx + numel,
+ numel=numel,
+ )
+ self.metadata[i] = metadata_obj
+ current_idx += numel
+
+ # Concatenate all flattened tensors
+ self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+ else:
+ # Initialize from pre-flattened data
+ if flattened_tensor is None or metadata is None:
+ raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata")
+ self.flattened_tensor = flattened_tensor
+ self.metadata = metadata
+
+ def get_flattened_tensor(self) -> torch.Tensor:
+ """Get the flattened tensor containing all bucket tensors"""
+ return self.flattened_tensor
+
+ def get_metadata(self) -> List[FlattenedTensorMetadata]:
+ """Get metadata for all tensors in the bucket"""
+ return self.metadata
+
+ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+ """
+ Reconstruct original tensors from flattened tensor with optimized performance.
+ Uses memory-efficient operations to minimize allocations and copies.
+ """
+ # preallocate the result list
+ reconstructed = [None] * len(self.metadata)
+
+ for i, meta in enumerate(self.metadata):
+ tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape)
+
+ reconstructed[i] = (meta.name, tensor)
+
+ return reconstructed
diff --git a/lightllm/utils/rl/torch_cuda_ipc.py b/lightllm/utils/rl/torch_cuda_ipc.py
new file mode 100644
index 0000000000..f2158c7a43
--- /dev/null
+++ b/lightllm/utils/rl/torch_cuda_ipc.py
@@ -0,0 +1,119 @@
+"""
+Patch torch multiprocessing CUDA tensor reductions to make cross-process
+CUDA IPC robust when different processes have different CUDA_VISIBLE_DEVICES.
+
+Torch serializes CUDA tensor IPC handles with a device index. That index is
+local to each process, so sender cuda:0 and receiver cuda:0 may refer to
+different physical GPUs. We replace the serialized device index with the GPU
+UUID on send, then map that UUID back to the receiver's local device index
+while rebuilding the tensor.
+
+The patch wraps torch's original reducers and only changes the device argument,
+so it avoids copying torch's CUDA IPC serialization implementation.
+
+Copied from:
+https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py
+"""
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Callable, Union
+
+import torch
+from torch.multiprocessing import reductions
+
+
+def monkey_patch_torch_reductions():
+ """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
+
+ # Currently, NPU does not support UUID. This has been temporarily commented out,
+ # with support expected in the fourth quarter.
+ # if _is_npu:
+ # return
+
+ if hasattr(reductions, "_reduce_tensor_original"):
+ return
+
+ reductions._reduce_tensor_original = reductions.reduce_tensor
+ reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor
+
+ reductions.reduce_tensor = _reduce_tensor_modified
+ reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified
+
+ reductions.init_reductions()
+
+
+# The torch CUDA IPC rebuild signature has kept the device argument at this
+# index for years. Keep this constant in one place because both the global
+# monkey patch and local bucketed IPC rebuild path need to rewrite it.
+CUDA_IPC_REBUILD_DEVICE_ARG_INDEX = 6
+_rebuild_device_fallback: ContextVar[Union[int, None]] = ContextVar("rebuild_device_fallback", default=None)
+
+
+@contextmanager
+def cuda_rebuild_device_fallback(device: Union[int, None]):
+ token = _rebuild_device_fallback.set(device)
+ try:
+ yield
+ finally:
+ _rebuild_device_fallback.reset(token)
+
+
+def _reduce_tensor_modified(*args, **kwargs):
+ output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
+ output_args = _modify_tuple(output_args, CUDA_IPC_REBUILD_DEVICE_ARG_INDEX, cuda_device_to_uuid)
+ return output_fn, output_args
+
+
+def _rebuild_cuda_tensor_modified(*args):
+ args = _modify_tuple(args, CUDA_IPC_REBUILD_DEVICE_ARG_INDEX, cuda_device_from_maybe_uuid)
+ return reductions._rebuild_cuda_tensor_original(*args)
+
+
+def get_current_device_name() -> str:
+ if torch.cuda.is_available():
+ return "cuda"
+ return "cpu"
+
+
+def get_current_device_module():
+ device_name = get_current_device_name()
+ try:
+ return getattr(torch, device_name)
+ except AttributeError:
+ return torch.cuda
+
+
+def get_current_device_id() -> int:
+ return get_current_device_module().current_device()
+
+
+def cuda_device_to_uuid(device: int) -> str:
+ return str(torch.cuda.get_device_properties(device).uuid)
+
+
+def cuda_device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
+ if isinstance(device_maybe_uuid, int):
+ return device_maybe_uuid
+
+ if isinstance(device_maybe_uuid, str):
+ for device in range(torch.cuda.device_count()):
+ if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
+ return device
+ fallback_device = _rebuild_device_fallback.get()
+ if fallback_device is not None:
+ return fallback_device
+ raise Exception("Invalid device_uuid=" + device_maybe_uuid)
+
+ raise Exception(f"Unknown type: {device_maybe_uuid=}")
+
+
+def rebuild_cuda_ipc_tensor(handle: tuple[Callable, tuple], device_id: Union[int, None] = None) -> torch.Tensor:
+ func, args = handle
+ list_args = list(args)
+ if device_id is not None:
+ list_args[CUDA_IPC_REBUILD_DEVICE_ARG_INDEX] = device_id
+ return func(*list_args)
+
+
+def _modify_tuple(t, index: int, modifier: Callable):
+ return *t[:index], modifier(t[index]), *t[index + 1 :]
diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py
new file mode 100644
index 0000000000..90d1c9e6b1
--- /dev/null
+++ b/lightllm/utils/torch_memory_saver_utils.py
@@ -0,0 +1,94 @@
+import torch
+from contextlib import contextmanager
+from enum import Enum
+from lightllm.utils.log_utils import init_logger
+
+try:
+ from torch_memory_saver import (
+ torch_memory_saver,
+ configure_subprocess as tms_configure_subprocess,
+ )
+
+ HAS_TORCH_MEMORY_SAVER = True
+
+except ImportError:
+ HAS_TORCH_MEMORY_SAVER = False
+ pass
+
+logger = init_logger(__name__)
+
+
+class MemoryTag(Enum):
+ KV_CACHE = "kv_cache"
+ WEIGHT = "weights"
+ GRAPH = "graph"
+
+ def is_kv_cache(self):
+ return self == MemoryTag.KV_CACHE
+
+ def is_weight(self):
+ return self == MemoryTag.WEIGHT
+
+ def is_graph(self):
+ return self == MemoryTag.GRAPH
+
+ def __str__(self):
+ return self.value
+
+
+class TorchMemorySaverWrapper:
+ def __new__(cls, enable_torch_memory_saver: bool = False):
+ if enable_torch_memory_saver:
+ assert (
+ HAS_TORCH_MEMORY_SAVER
+ ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`."
+ return _TorchMemorySaver()
+ else:
+ return _TorchMemorySaverFake()
+
+
+class _TorchMemorySaver:
+ @contextmanager
+ def configure_subprocess(self):
+ with tms_configure_subprocess():
+ yield
+
+ def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
+ return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup)
+
+ def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
+ return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value)
+
+ def disable(self):
+ return torch_memory_saver.disable()
+
+ def pause(self, tag: MemoryTag):
+ return torch_memory_saver.pause(tag=tag.value)
+
+ def resume(self, tag: MemoryTag):
+ return torch_memory_saver.resume(tag=tag.value)
+
+
+class _TorchMemorySaverFake:
+ @contextmanager
+ def configure_subprocess(self):
+ yield
+
+ @contextmanager
+ def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
+ yield
+
+ def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
+ return torch.cuda.graph(graph_obj, **kwargs)
+
+ @contextmanager
+ def disable(self):
+ yield
+
+ def pause(self, tag: MemoryTag):
+ logger.warning("torch_memory_saver is not enabled, pause is not supported.")
+ return
+
+ def resume(self, tag: MemoryTag):
+ logger.warning("torch_memory_saver is not enabled, resume is not supported.")
+ return
diff --git a/requirements.txt b/requirements.txt
index 603d0e488f..180f9d3ced 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -98,3 +98,4 @@ nixl==1.2.0
xformers==0.0.35
redis==7.3.0
litellm>=1.52.0,<1.85
+torch_memory_saver==0.0.9.post1
\ No newline at end of file
diff --git a/test/test_api/test_abort_chaos.py b/test/test_api/test_abort_chaos.py
new file mode 100644
index 0000000000..63f226717e
--- /dev/null
+++ b/test/test_api/test_abort_chaos.py
@@ -0,0 +1,105 @@
+"""
+Two-stage abort test against a running lightllm server.
+
+Stage 1: spawn N concurrent streams, then post /abort_request abort_all=True;
+ verify every stream terminates quickly.
+Stage 2: spawn N concurrent streams; each stream is independently assigned a
+ random fate (disconnect mid-stream or run to completion). The server
+ must keep serving the survivors and stay healthy afterwards.
+
+Usage:
+ python test/test_api/test_abort_chaos.py --url http://127.0.0.1:8000
+"""
+
+import argparse
+import asyncio
+import json
+import random
+import time
+from collections import Counter
+
+import httpx
+
+
+PROMPTS = [
+ "Write a long detailed essay about the history of computing.",
+ "Tell me a long story about dragons and knights.",
+ "Explain quantum mechanics in detail with lots of examples.",
+ "Describe the plot of a 5-part fantasy novel series.",
+ "Compose a long poem about the seasons.",
+]
+
+
+async def stream_task(client, url, mode, max_new_tokens):
+ payload = {
+ "inputs": random.choice(PROMPTS),
+ "parameters": {"max_new_tokens": max_new_tokens, "temperature": 0.7, "do_sample": True},
+ }
+ drop_after = random.randint(20, 200)
+ tokens = 0
+ finish_reason = None
+ t0 = time.time()
+ try:
+ async with client.stream("POST", f"{url}/generate_stream", json=payload, timeout=180.0) as r:
+ async for line in r.aiter_lines():
+ tokens += 1
+ if line.startswith("data:"):
+ chunk = json.loads(line[len("data:") :])
+ if chunk.get("finished"):
+ finish_reason = chunk.get("finish_reason")
+ if mode == "disconnect" and tokens >= drop_after:
+ break
+ return (mode, finish_reason or "ok", tokens, time.time() - t0)
+ except Exception as e:
+ return (mode, f"exc:{type(e).__name__}", tokens, time.time() - t0)
+
+
+def summarize(results):
+ outcomes = Counter()
+ for r in results:
+ outcomes[(r[0], r[1]) if isinstance(r, tuple) else f"raised:{type(r).__name__}"] += 1
+ for k, v in sorted(outcomes.items(), key=str):
+ print(f" {k}: {v}")
+
+
+async def stage_abort_all(client, url, concurrency, max_new_tokens):
+ print("\n===== STAGE 1: abort_all on N concurrent streams =====")
+ tasks = [asyncio.create_task(stream_task(client, url, "finish", max_new_tokens)) for _ in range(concurrency)]
+ await asyncio.sleep(2.0)
+ t0 = time.time()
+ r = await client.post(f"{url}/abort_request", json={"abort_all": True}, timeout=10.0)
+ print(f"abort_all status={r.status_code}")
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ print(f"all streams settled in {time.time() - t0:.2f}s")
+ summarize(results)
+
+
+async def stage_random_chaos(client, url, concurrency, max_new_tokens):
+ print("\n===== STAGE 2: random per-stream chaos =====")
+ modes = random.choices(["disconnect", "finish"], weights=[80, 20], k=concurrency)
+ tasks = [asyncio.create_task(stream_task(client, url, m, max_new_tokens)) for m in modes]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ summarize(results)
+
+
+async def run(url, concurrency, max_new_tokens):
+ async with httpx.AsyncClient() as client:
+ await stage_abort_all(client, url, concurrency, max_new_tokens)
+ await stage_random_chaos(client, url, concurrency, max_new_tokens)
+ print("\nALL CHAOS TESTS PASSED")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--url", default="http://127.0.0.1:8000")
+ parser.add_argument("--concurrency", type=int, default=24)
+ parser.add_argument("--max_new_tokens", type=int, default=2048)
+ parser.add_argument("--seed", type=int, default=42)
+ args = parser.parse_args()
+
+ random.seed(args.seed)
+ asyncio.run(run(args.url, args.concurrency, args.max_new_tokens))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/test_api/test_abort_request.py b/test/test_api/test_abort_request.py
new file mode 100644
index 0000000000..ca99f6298c
--- /dev/null
+++ b/test/test_api/test_abort_request.py
@@ -0,0 +1,432 @@
+"""
+Test the /abort_request endpoint against a running lightllm server.
+
+What this test asserts (and why it does not assert "stream becomes finish_reason='abort'"):
+
+ In normal / chunked_prefill mode, /abort_request:
+ - sets shm_req.is_aborted = True
+ - drives the router to send AbortedReqCmd, which sets InferReq.infer_aborted = True
+ on the worker
+ - causes still-waiting (not yet scheduled) reqs to be freed with FINISHED_ABORTED
+ - but does NOT cause already-running reqs to early-exit; they finish at max_new_tokens
+ / EOS / stop sequence as usual. (The shm flag is consumed by audio/visual servers
+ and pd_nixl mode, but the LLM inference loop never short-circuits on it.)
+
+So the test verifies the contract that actually exists today:
+
+ Stage A: bogus request_id -> HTTP 200, server log "not exist" warning
+ Stage B: abort_all on an idle server -> HTTP 200, no errors
+ Stage C: abort_all on a running stream
+ -> HTTP 200; server log shows "aborted group_request_id N" warning
+ -> the stream terminates within reasonable time (whether via abort or
+ natural max_new_tokens completion)
+ Stage D: abort by SPECIFIC request_id on a running stream
+ -> resolve the lightllm_req_id from the server log (via X-Request-Id),
+ POST /abort_request with that exact id, verify the targeted log
+ warning lands and the stream terminates
+ Stage E: server remains healthy and answers a fresh /generate
+
+Usage:
+ python test/test_api/test_abort_request.py \
+ --url http://127.0.0.1:8000 \
+ --server_log_path /tmp/lightllm_test/server.log
+"""
+
+import argparse
+import json
+import os
+import re
+import sys
+import threading
+import time
+import uuid
+from typing import List, Optional, Tuple
+
+import requests
+
+
+GREEN = "\033[32m"
+RED = "\033[31m"
+YELLOW = "\033[33m"
+RESET = "\033[0m"
+
+
+def banner(msg: str):
+ print(f"\n{YELLOW}=== {msg} ==={RESET}", flush=True)
+
+
+def ok(msg: str):
+ print(f" {GREEN}OK{RESET} {msg}", flush=True)
+
+
+def fail(msg: str):
+ print(f" {RED}FAIL{RESET} {msg}", flush=True)
+
+
+# ---------------- HTTP helpers ----------------
+
+
+def _get_health(url: str, timeout=5):
+ return requests.get(url + "/health", timeout=timeout)
+
+
+def post_abort(url: str, request_id: Optional[int] = None, abort_all: bool = False) -> Tuple[int, str]:
+ payload = {"abort_all": abort_all}
+ if request_id is not None:
+ payload["request_id"] = request_id
+ r = requests.post(url + "/abort_request", json=payload, timeout=30)
+ return r.status_code, r.text
+
+
+# ---------------- streaming helpers ----------------
+
+
+def _stream_run(
+ url: str,
+ prompt: str,
+ max_new_tokens: int,
+ x_request_id: str,
+ out: dict,
+ close_after_n: Optional[int] = None,
+):
+ """
+ Issue a /generate_stream and append every event to out["events"].
+ If close_after_n is set, the underlying socket is forcibly closed
+ (TCP RST via SO_LINGER + close) after that many events arrive — kept
+ here for completeness even though no current stage uses it. Sets
+ out["error"] on transport errors.
+ """
+ headers = {"X-Request-Id": x_request_id, "Content-Type": "application/json"}
+ body = {
+ "inputs": prompt,
+ "parameters": {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": False,
+ "ignore_eos": True,
+ },
+ }
+ out["events"] = []
+ out["start"] = time.time()
+ out["error"] = None
+ out["closed_intentionally"] = False
+ try:
+ # urllib3 keeps the socket pooled; we need direct access to force-close.
+ with requests.post(url + "/generate_stream", json=body, headers=headers, stream=True, timeout=120) as r:
+ r.raise_for_status()
+ for raw in r.iter_lines(decode_unicode=True):
+ if not raw:
+ continue
+ if raw.startswith("data:"):
+ raw = raw[len("data:") :]
+ try:
+ ev = json.loads(raw)
+ except Exception:
+ continue
+ ev["_t"] = time.time() - out["start"]
+ out["events"].append(ev)
+ if close_after_n is not None and len(out["events"]) >= close_after_n:
+ out["closed_intentionally"] = True
+ # Reach into urllib3 to force a TCP RST so the server sees
+ # the disconnect immediately rather than after a graceful
+ # FIN that hypercorn might not propagate while the response
+ # is mid-stream.
+ try:
+ import socket as _socket
+
+ sock = r.raw._fp.fp.raw._sock # type: ignore[attr-defined]
+ # SO_LINGER with timeout 0 -> RST on close.
+ l_onoff, l_linger = 1, 0
+ sock.setsockopt(
+ _socket.SOL_SOCKET,
+ _socket.SO_LINGER,
+ int.to_bytes(l_onoff, 4, "little") + int.to_bytes(l_linger, 4, "little"),
+ )
+ sock.close()
+ except Exception as e:
+ out["close_error"] = repr(e)
+ break
+ if ev.get("finished"):
+ break
+ except Exception as e:
+ out["error"] = repr(e)
+ out["end"] = time.time()
+
+
+def start_stream(
+ url: str, prompt: str, max_new_tokens: int, close_after_n: Optional[int] = None
+) -> Tuple[threading.Thread, dict, str]:
+ xid = uuid.uuid4().hex
+ out = {}
+ th = threading.Thread(target=_stream_run, args=(url, prompt, max_new_tokens, xid, out, close_after_n))
+ th.daemon = True
+ th.start()
+ return th, out, xid
+
+
+def wait_for_first_token(out: dict, timeout: float = 30.0) -> bool:
+ deadline = time.time() + timeout
+ while time.time() < deadline:
+ if out.get("events"):
+ return True
+ time.sleep(0.05)
+ return False
+
+
+def get_finish_reason(out: dict) -> Optional[str]:
+ for ev in reversed(out.get("events") or []):
+ fr = ev.get("finish_reason")
+ if fr:
+ return fr
+ return None
+
+
+# ---------------- log helpers ----------------
+
+
+def _read_log_tail(server_log_path: Optional[str], max_bytes: int = 256 * 1024) -> str:
+ if not server_log_path or not os.path.exists(server_log_path):
+ return ""
+ try:
+ size = os.path.getsize(server_log_path)
+ with open(server_log_path, "rb") as f:
+ if size > max_bytes:
+ f.seek(size - max_bytes)
+ return f.read().decode("utf-8", errors="ignore")
+ except FileNotFoundError:
+ return ""
+
+
+def grep_log_for_pattern(server_log_path: Optional[str], pattern: re.Pattern, timeout: float = 5.0) -> Optional[str]:
+ """Poll the tail of the server log for a regex match."""
+ deadline = time.time() + timeout
+ while time.time() < deadline:
+ tail = _read_log_tail(server_log_path)
+ m = pattern.search(tail)
+ if m:
+ return m.group(0)
+ time.sleep(0.1)
+ return None
+
+
+def grep_log_after_offset(
+ server_log_path: Optional[str], start_offset: int, pattern: re.Pattern, timeout: float = 5.0
+) -> Optional[str]:
+ """Poll the server log starting at start_offset for a regex match.
+ Only content written after start_offset is considered, so this isolates
+ a stage from log produced by earlier stages."""
+ if not server_log_path:
+ return None
+ deadline = time.time() + timeout
+ while time.time() < deadline:
+ try:
+ with open(server_log_path, "rb") as f:
+ f.seek(start_offset)
+ new = f.read().decode("utf-8", errors="ignore")
+ except FileNotFoundError:
+ new = ""
+ m = pattern.search(new)
+ if m:
+ return m.group(0)
+ time.sleep(0.1)
+ return None
+
+
+def server_log_size(server_log_path: Optional[str]) -> int:
+ if not server_log_path or not os.path.exists(server_log_path):
+ return 0
+ return os.path.getsize(server_log_path)
+
+
+def lookup_lightllm_req_id_from_log(server_log_path: str, x_request_id: str, timeout: float = 5.0) -> Optional[int]:
+ pattern = re.compile(rf"received req X-Request-Id:{re.escape(x_request_id)}\b.*?lightllm_req_id:(\d+)")
+ deadline = time.time() + timeout
+ while time.time() < deadline:
+ tail = _read_log_tail(server_log_path)
+ m = pattern.search(tail)
+ if m:
+ return int(m.group(1))
+ time.sleep(0.1)
+ return None
+
+
+# ---------------- stages ----------------
+
+
+def stage_a_bogus_id(url: str) -> bool:
+ banner("Stage A: abort with a non-existent id")
+ bogus = 99_999_999
+ code, text = post_abort(url, request_id=bogus, abort_all=False)
+ print(f" /abort_request request_id={bogus} -> HTTP {code} body={text!r}")
+ if code != 200:
+ fail(f"expected HTTP 200, got {code}")
+ return False
+ ok("HTTP 200")
+ return True
+
+
+def stage_b_abort_all_idle(url: str) -> bool:
+ banner("Stage B: abort_all on an idle server")
+ code, text = post_abort(url, abort_all=True)
+ print(f" /abort_request abort_all=true -> HTTP {code} body={text!r}")
+ if code != 200:
+ fail(f"expected HTTP 200, got {code}")
+ return False
+ ok("HTTP 200")
+ return True
+
+
+def stage_c_abort_running(url: str, server_log_path: Optional[str]) -> bool:
+ banner("Stage C: abort_all on a running stream")
+ log_offset = server_log_size(server_log_path)
+ th, out, xid = start_stream(url, "Recite the alphabet repeatedly.", max_new_tokens=200)
+ if not wait_for_first_token(out, timeout=30.0):
+ fail("did not receive any tokens before abort")
+ return False
+ first_t = out["events"][0]["_t"]
+ ok(f"first token at +{first_t:.2f}s")
+
+ target_id = lookup_lightllm_req_id_from_log(server_log_path, xid, timeout=5.0) if server_log_path else None
+ print(f" resolved lightllm_req_id from log: {target_id}")
+
+ code, text = post_abort(url, abort_all=True)
+ print(f" /abort_request abort_all=true -> HTTP {code} body={text!r}")
+ if code != 200:
+ fail(f"expected HTTP 200, got {code}")
+ return False
+
+ th.join(timeout=60.0)
+ if th.is_alive():
+ fail("stream did not terminate within 60s of abort")
+ return False
+ fr = get_finish_reason(out)
+ n = len(out.get("events") or [])
+ print(f" stream events received: {n}, finish_reason={fr!r}, error={out.get('error')!r}")
+
+ # The api itself succeeded; whether the stream got a clean 'abort' finish reason
+ # depends on which mode-backend the server is running. We DO assert the abort
+ # warning landed in the server log though, scoped to log content produced after
+ # this stage started so we don't match earlier-stage residue.
+ if server_log_path:
+ if target_id is not None:
+ pat = re.compile(rf"aborted group_request_id {target_id}\b")
+ else:
+ pat = re.compile(r"aborted group_request_id \d+")
+ hit = grep_log_after_offset(server_log_path, log_offset, pat, timeout=5.0)
+ if not hit:
+ fail("could not find 'aborted group_request_id' in server log (post-stage)")
+ return False
+ ok(f"server log recorded: {hit!r}")
+ else:
+ print(" no --server_log_path; skipped log assertion")
+ ok("stream terminated and abort acknowledged")
+ return True
+
+
+def stage_d_abort_by_id(url: str, server_log_path: Optional[str]) -> bool:
+ banner("Stage D: abort by specific request_id on a running stream")
+ if not server_log_path:
+ print(" --server_log_path not provided; skipping (we need the log to resolve req_id)")
+ return True
+
+ log_offset = server_log_size(server_log_path)
+ th, out, xid = start_stream(url, "Sing a long lullaby for the moon.", max_new_tokens=300)
+ if not wait_for_first_token(out, timeout=30.0):
+ fail("did not receive any tokens before abort")
+ return False
+ ok(f"first token at +{out['events'][0]['_t']:.2f}s, X-Request-Id={xid[:8]}…")
+
+ target_id = lookup_lightllm_req_id_from_log(server_log_path, xid, timeout=5.0)
+ if target_id is None:
+ fail("could not resolve lightllm_req_id from server log; cannot test by-id abort")
+ return False
+ print(f" resolved lightllm_req_id: {target_id}")
+
+ code, text = post_abort(url, request_id=target_id, abort_all=False)
+ print(f" /abort_request request_id={target_id} -> HTTP {code} body={text!r}")
+ if code != 200:
+ fail(f"expected HTTP 200, got {code}")
+ return False
+
+ th.join(timeout=60.0)
+ if th.is_alive():
+ fail("stream did not terminate within 60s")
+ return False
+ fr = get_finish_reason(out)
+ n = len(out.get("events") or [])
+ print(f" stream events received: {n}, finish_reason={fr!r}")
+
+ pat = re.compile(rf"aborted group_request_id {target_id}\b")
+ hit = grep_log_after_offset(server_log_path, log_offset, pat, timeout=5.0)
+ if not hit:
+ fail(f"could not find 'aborted group_request_id {target_id}' in server log (post-stage)")
+ return False
+ ok(f"server log recorded: {hit!r}")
+ return True
+
+
+def stage_e_health_after(url: str) -> bool:
+ banner("Stage E: server still serves a normal /generate")
+ r = requests.post(
+ url + "/generate",
+ json={
+ "inputs": "The capital of France is",
+ "parameters": {"max_new_tokens": 6, "do_sample": False},
+ },
+ timeout=60,
+ )
+ print(f" /generate -> HTTP {r.status_code} {r.text[:200]}")
+ if r.status_code != 200:
+ fail(f"final /generate failed with {r.status_code}")
+ return False
+ body = r.json()
+ text = body.get("generated_text")
+ if isinstance(text, list):
+ text = text[0]
+ if not text or not text.strip():
+ fail("final /generate returned empty text")
+ return False
+ ok(f"final /generate returned {text!r}")
+ return True
+
+
+# ---------------- main ----------------
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--url", default="http://127.0.0.1:8000")
+ ap.add_argument(
+ "--server_log_path",
+ default=None,
+ help="optional path to the server stdout/stderr log; enables log-grep assertions",
+ )
+ args = ap.parse_args()
+
+ try:
+ r = _get_health(args.url)
+ r.raise_for_status()
+ except Exception as e:
+ fail(f"server at {args.url} not reachable: {e}")
+ sys.exit(1)
+ ok(f"server reachable at {args.url}")
+
+ results = []
+ results.append(("A", stage_a_bogus_id(args.url)))
+ results.append(("B", stage_b_abort_all_idle(args.url)))
+ results.append(("C", stage_c_abort_running(args.url, args.server_log_path)))
+ results.append(("D", stage_d_abort_by_id(args.url, args.server_log_path)))
+ results.append(("E", stage_e_health_after(args.url)))
+
+ print("\n" + "=" * 50)
+ all_ok = True
+ for name, passed in results:
+ tag = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}"
+ print(f" Stage {name}: {tag}")
+ all_ok = all_ok and passed
+ if not all_ok:
+ sys.exit(1)
+ print(f"\n{GREEN}ALL ABORT STAGES PASSED{RESET}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py
new file mode 100644
index 0000000000..85c4e44ef9
--- /dev/null
+++ b/test/test_api/test_r3.py
@@ -0,0 +1,92 @@
+import sys
+import argparse
+import requests
+import base64
+import numpy as np
+
+
+def test_routing_export(url: str = "http://localhost:8000"):
+ print(f"Testing routing export at {url}")
+ print("-" * 50)
+
+ try:
+ response = requests.post(
+ f"{url}/generate",
+ json={
+ "inputs": "What is the capital of France? What is the capital of France?",
+ "parameters": {
+ "max_new_tokens": 50,
+ # "return_routed_experts": True,
+ # "repetition_penalty": 1.0,
+ },
+ },
+ timeout=60,
+ )
+ except requests.exceptions.ConnectionError:
+ print(f"ERROR: Cannot connect to server at {url}")
+ print("Make sure the LightLLM server is running with --enable_return_routed_experts")
+ return False
+ except requests.exceptions.Timeout:
+ print("ERROR: Request timed out")
+ return False
+
+ print(f"Status: {response.status_code}")
+
+ if response.status_code != 200:
+ print(f"ERROR: Request failed with status {response.status_code}")
+ print(f"Response: {response.text}")
+ return False
+
+ res = response.json()
+ print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...")
+
+ if "routed_experts" not in res or not res["routed_experts"]:
+ print("\nWARNING: No routed_experts in response.")
+ print("This could mean:")
+ print(" - The model is not a MoE model")
+ print(" - The server was not started with --enable_return_routed_experts")
+ print(" - The routing capture manager was not initialized")
+ return False
+
+ routing_info = res["routed_experts"]
+ shape = routing_info["shape"]
+ dtype_str = routing_info["dtype"]
+ dtype = np.dtype(dtype_str)
+ data = base64.b64decode(routing_info["data"])
+ routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)
+
+ print(f"\n{'=' * 50}")
+ print("ROUTING CAPTURE SUCCESS!")
+ print(f"{'=' * 50}")
+ print(f"Shape: {shape}")
+ print(f"Dtype: {dtype}")
+ print(f"Num tokens: {shape[0]}")
+ print(f"Num MoE layers: {shape[1]}")
+ print(f"Top-K: {shape[2]}")
+
+ # Compute payload size savings
+ int32_size = np.prod(shape) * 4
+ actual_size = len(data)
+ savings = (1 - actual_size / int32_size) * 100
+ print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)")
+
+ print(f"\nSample routing (first layer, first 5 tokens):")
+ num_tokens_to_show = shape[0]
+ for i in range(num_tokens_to_show):
+ print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}")
+
+ if np.all(routing_array == 0):
+ print("\nWARNING: All routing data is zeros. Capture may not be working correctly.")
+ return False
+
+ print("\nTest PASSED!")
+ return True
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Test R3 routing export feature")
+ parser.add_argument("--url", default="http://localhost:8000", help="Server URL")
+ args = parser.parse_args()
+
+ success = test_routing_export(args.url)
+ sys.exit(0 if success else 1)
diff --git a/test/test_api/test_rl_endpoints.py b/test/test_api/test_rl_endpoints.py
new file mode 100644
index 0000000000..91829530ec
--- /dev/null
+++ b/test/test_api/test_rl_endpoints.py
@@ -0,0 +1,349 @@
+"""
+Test release_memory_occupation / resume_memory_occupation / update_weights_from_tensor
+against a running lightllm server.
+
+Sequence:
+ 1. baseline generate (sanity)
+ 2. release_memory_occupation -> GPU memory should drop sharply
+ 3. resume_memory_occupation -> GPU memory should grow back
+ (without --enable_weight_cpu_backup the weight
+ memory is allocated empty, so generation right
+ after resume is expected to be garbage)
+ 4. update_weights_from_tensor (per-batch CUDA-IPC handoff) for every parameter
+ found on disk -> repopulate weights
+ 5. final generate -> should produce a sensible answer again
+
+The "trainer" runs in this same process: it holds tensors on a free GPU, serialises
+them via lightllm.utils.rl.serialization.MultiprocessingSerializer (CUDA IPC handles, not
+data), then asks the server to clone them into its weight buffers. No NCCL group
+is required, so this is safe to interrupt without leaving the server hung.
+
+Usage:
+ python test/test_api/test_rl_endpoints.py \
+ --url http://127.0.0.1:8000 \
+ --model_dir /nvme/models/Qwen3.5-35B-A3B \
+ --tp 4 \
+ --server_devices 0,1,2,3 \
+ --client_device 4
+
+Notes:
+ - This script must run on the same machine as the server (CUDA IPC).
+ - --server_devices are nvidia-smi GPU indices for the TP workers. If omitted,
+ the script infers them from the top --tp memory consumers before release.
+ - --client_device picks a free CUDA device for the in-process trainer; it is
+ independent from --server_devices and should not overlap the TP workers.
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import time
+from glob import glob
+from typing import Dict, List, Tuple
+
+# Make the repo importable when this script is invoked by path rather than -m.
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _REPO_ROOT not in sys.path:
+ sys.path.insert(0, _REPO_ROOT)
+
+import requests
+import torch
+from safetensors import safe_open
+
+from lightllm.utils.rl.serialization import MultiprocessingSerializer
+from lightllm.utils.rl.torch_cuda_ipc import monkey_patch_torch_reductions
+
+
+GREEN = "\033[32m"
+RED = "\033[31m"
+YELLOW = "\033[33m"
+RESET = "\033[0m"
+
+
+def banner(msg: str):
+ print(f"\n{YELLOW}=== {msg} ==={RESET}", flush=True)
+
+
+def ok(msg: str):
+ print(f" {GREEN}OK{RESET} {msg}", flush=True)
+
+
+def fail(msg: str):
+ print(f" {RED}FAIL{RESET} {msg}", flush=True)
+
+
+def gpu_mem_used_mib() -> List[int]:
+ out = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]).decode()
+ return [int(x.strip()) for x in out.strip().splitlines()]
+
+
+def _select_gpu_mem(mem: List[int], devices: List[int]) -> List[int]:
+ return [mem[i] for i in devices]
+
+
+def _resolve_server_devices(server_devices: str, tp: int, mem: List[int]) -> List[int]:
+ if tp <= 0:
+ raise ValueError(f"--tp must be positive, got {tp}")
+ if tp > len(mem):
+ raise ValueError(f"--tp={tp} but nvidia-smi only returned {len(mem)} GPUs")
+
+ value = server_devices.strip()
+ if value.lower() == "auto":
+ return sorted(range(len(mem)), key=lambda i: mem[i], reverse=True)[:tp]
+
+ devices = [int(x.strip()) for x in value.split(",") if x.strip()]
+ if len(devices) != tp:
+ raise ValueError(f"--server_devices must contain exactly --tp entries; got {devices} for tp={tp}")
+ if len(set(devices)) != len(devices):
+ raise ValueError(f"--server_devices contains duplicates: {devices}")
+
+ bad = [i for i in devices if i < 0 or i >= len(mem)]
+ if bad:
+ raise ValueError(f"--server_devices contains invalid GPU indices {bad}; nvidia-smi returned {len(mem)} GPUs")
+ return devices
+
+
+def post(url: str, path: str, payload=None, timeout=600):
+ r = requests.post(url + path, json=payload or {}, timeout=timeout)
+ try:
+ body = r.json()
+ except Exception:
+ body = r.text
+ return r.status_code, body
+
+
+def generate(url: str, prompt: str, max_new_tokens: int = 16) -> str:
+ r = requests.post(
+ url + "/generate",
+ json={
+ "inputs": prompt,
+ "parameters": {"max_new_tokens": max_new_tokens, "do_sample": False},
+ },
+ timeout=120,
+ )
+ r.raise_for_status()
+ data = r.json()
+ if isinstance(data.get("generated_text"), list):
+ return data["generated_text"][0]
+ return data.get("generated_text", json.dumps(data))
+
+
+def looks_garbage(text: str) -> bool:
+ """Heuristic: post-resume text is usually a single repeated character (e.g. '!!!!')."""
+ s = text.strip()
+ if not s:
+ return True
+ return len(set(s)) == 1
+
+
+# ---------------- weight-update helpers (update_weights_from_tensor) ----------------
+
+
+def _list_safetensor_shards(model_dir: str) -> List[str]:
+ shards = sorted(glob(os.path.join(model_dir, "*.safetensors")))
+ if not shards:
+ raise RuntimeError(f"no .safetensors found under {model_dir}")
+ return shards
+
+
+def _send_update_from_tensor(
+ url: str,
+ serialized_per_rank: List[str],
+ flush_cache: bool = False,
+):
+ code, body = post(
+ url,
+ "/update_weights_from_tensor",
+ {
+ "serialized_named_tensors": serialized_per_rank,
+ "load_format": None,
+ "flush_cache": flush_cache,
+ "abort_all_requests": False,
+ },
+ timeout=600,
+ )
+ return code, body
+
+
+def update_weights_from_disk_via_tensor_api(
+ url: str,
+ model_dir: str,
+ tp: int,
+ client_device: int,
+ batch_per_request: int = 8,
+ flush_cache_at_end: bool = True,
+):
+ """
+ Acts as an in-process "trainer": loads every safetensor shard onto
+ cuda:client_device, then ships each batch of (name, tensor) to the server
+ via /update_weights_from_tensor. The server worker on each TP rank receives
+ a CUDA IPC handle, copies into its weight buffer.
+ """
+ banner("update_weights_from_tensor (CUDA IPC)")
+ # Server side patches its own copy; we patch ours so reductions can serialise
+ # CUDA tensors with UUID-based device addressing.
+ monkey_patch_torch_reductions()
+ torch.cuda.set_device(client_device)
+ device = f"cuda:{client_device}"
+
+ shards = _list_safetensor_shards(model_dir)
+ print(f" found {len(shards)} safetensor shards, batch_per_request={batch_per_request}", flush=True)
+
+ total_params = 0
+ total_bytes = 0
+ t0 = time.time()
+ for shard_idx, shard in enumerate(shards):
+ shard_t0 = time.time()
+ with safe_open(shard, framework="pt") as f:
+ keys = list(f.keys())
+ for i in range(0, len(keys), batch_per_request):
+ batch_keys = keys[i : i + batch_per_request]
+ # Load batch onto the client GPU. .contiguous() guarantees a
+ # whole-tensor allocation (safetensors slices are already
+ # contiguous, but this is cheap insurance).
+ tensors = [f.get_tensor(k).to(device).contiguous() for k in batch_keys]
+ named: List[Tuple[str, torch.Tensor]] = list(zip(batch_keys, tensors))
+
+ # Same payload to every TP rank — the server clones full
+ # tensors per rank and lets model.load_weights handle the TP
+ # sharding internally (matching how update_weights_from_*
+ # paths are written).
+ blob = MultiprocessingSerializer.serialize(named, output_str=True)
+ serialized_per_rank = [blob] * tp
+
+ # Last batch flushes the prefix cache so old KV from the
+ # previous weight version cannot poison subsequent gens.
+ is_last = (shard_idx == len(shards) - 1) and (i + batch_per_request >= len(keys))
+ code, body = _send_update_from_tensor(
+ url,
+ serialized_per_rank,
+ flush_cache=(flush_cache_at_end and is_last),
+ )
+ if code != 200:
+ fail(f"update batch failed: {code} {body}")
+ raise RuntimeError(f"update batch failed: {code} {body}")
+ total_params += len(batch_keys)
+ total_bytes += sum(t.numel() * t.element_size() for t in tensors)
+ # Free client-side memory before next batch — the worker has
+ # already cloned the data by the time post() returned.
+ for t in tensors:
+ del t
+ del tensors, named
+ torch.cuda.empty_cache()
+
+ print(
+ f" shard {shard_idx+1}/{len(shards)} done "
+ f"(+{len(keys)} tensors, {time.time()-shard_t0:.1f}s, "
+ f"running total {total_params} params, {total_bytes/1e9:.1f} GB)",
+ flush=True,
+ )
+
+ dt = time.time() - t0
+ ok(f"streamed {total_params} params, {total_bytes/1e9:.1f} GB in {dt:.1f}s")
+
+
+# ---------------- main flow ----------------
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--url", default="http://127.0.0.1:8000")
+ ap.add_argument("--model_dir", required=True)
+ ap.add_argument("--tp", type=int, required=True)
+ ap.add_argument(
+ "--server_devices",
+ default="auto",
+ help="comma-separated nvidia-smi GPU indices used by the server, or 'auto' to infer from memory usage",
+ )
+ ap.add_argument(
+ "--client_device",
+ type=int,
+ default=2,
+ help="GPU index for the in-process trainer; must differ from TP worker GPUs",
+ )
+ ap.add_argument("--prompt", default="The capital of France is")
+ ap.add_argument("--max_new_tokens", type=int, default=16)
+ ap.add_argument("--batch_per_request", type=int, default=8)
+ ap.add_argument("--skip_update", action="store_true", help="run only release/resume, skip the update_weights phase")
+ args = ap.parse_args()
+
+ # ---------------- stage 1: baseline ----------------
+ banner("baseline generate")
+ base_text = generate(args.url, args.prompt, args.max_new_tokens)
+ print(f" prompt : {args.prompt!r}")
+ print(f" generated: {base_text!r}")
+ ok("baseline generated")
+
+ # ---------------- stage 2: release ----------------
+ banner("release_memory_occupation")
+ before = gpu_mem_used_mib()
+ try:
+ server_devices = _resolve_server_devices(args.server_devices, args.tp, before)
+ except ValueError as e:
+ fail(str(e))
+ sys.exit(1)
+ print(f" server GPUs : {server_devices}")
+ print(f" GPU mem before: {_select_gpu_mem(before, server_devices)}")
+ code, body = post(args.url, "/release_memory_occupation", {})
+ print(f" resp: {code} {body}")
+ if code != 200:
+ fail("release failed")
+ sys.exit(1)
+ time.sleep(2)
+ after = gpu_mem_used_mib()
+ print(f" GPU mem after : {_select_gpu_mem(after, server_devices)}")
+ drop = sum(_select_gpu_mem(before, server_devices)) - sum(_select_gpu_mem(after, server_devices))
+ if drop < 10_000:
+ fail(f"release did not free much memory (delta={drop} MiB)")
+ sys.exit(1)
+ ok(f"release freed ~{drop} MiB on TP GPUs")
+
+ # ---------------- stage 3: resume ----------------
+ banner("resume_memory_occupation")
+ code, body = post(args.url, "/resume_memory_occupation", {})
+ print(f" resp: {code} {body}")
+ if code != 200:
+ fail("resume failed")
+ sys.exit(1)
+ time.sleep(2)
+ print(f" GPU mem after : {_select_gpu_mem(gpu_mem_used_mib(), server_devices)}")
+ ok("resume returned success")
+
+ banner("post-resume generate (likely garbage without weight cpu backup)")
+ text_after_resume = generate(args.url, args.prompt, args.max_new_tokens)
+ print(f" generated: {text_after_resume!r} garbage_heuristic={looks_garbage(text_after_resume)}")
+
+ if args.skip_update:
+ ok("done (skipped update_weights stage)")
+ return
+
+ # ---------------- stage 4: update_weights_from_tensor ----------------
+ update_weights_from_disk_via_tensor_api(
+ url=args.url,
+ model_dir=args.model_dir,
+ tp=args.tp,
+ client_device=args.client_device,
+ batch_per_request=args.batch_per_request,
+ flush_cache_at_end=True,
+ )
+
+ # ---------------- stage 5: final generate ----------------
+ banner("final generate (after weight reload)")
+ final_text = generate(args.url, args.prompt, args.max_new_tokens)
+ print(f" prompt : {args.prompt!r}")
+ print(f" generated: {final_text!r}")
+ if looks_garbage(final_text):
+ fail("final generation still looks like garbage; weight update did not stick")
+ sys.exit(1)
+ if final_text.strip() == base_text.strip():
+ ok("final output matches baseline exactly")
+ else:
+ ok("final output is sensible (differs from baseline but not garbage)")
+
+ print(f"\n{GREEN}ALL STAGES PASSED{RESET}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py
new file mode 100644
index 0000000000..1729187856
--- /dev/null
+++ b/unit_tests/common/basemodel/test_routing_capture_manager.py
@@ -0,0 +1,265 @@
+import numpy as np
+import pytest
+import torch
+
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+
+def _skip_without_cuda():
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA is required for routing capture.")
+
+
+class TestRoutingCaptureManager:
+ def test_capture_and_extract_basic(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=4,
+ topk=8,
+ num_experts=64,
+ kv_cache_size=1024,
+ max_capture_tokens=64,
+ )
+ mem_indexes = torch.arange(100, 110, device="cuda")
+ expected = np.zeros((10, 4, 8), dtype=np.uint8)
+
+ make_capture_callback = manager.make_capture_callback_factory(mem_indexes)
+ for layer_idx in range(4):
+ topk_ids = torch.randint(0, 64, (10, 8), device="cuda")
+ make_capture_callback(layer_idx)(topk_ids)
+ expected[:, layer_idx, :] = topk_ids.cpu().numpy().astype(np.uint8)
+
+ result = manager.extract_routing_data(mem_indexes)
+ assert result.shape == (10, 4, 8)
+ assert result.dtype == np.uint8
+ np.testing.assert_array_equal(result, expected)
+
+ def test_capture_writes_to_correct_kv_positions(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=2,
+ topk=4,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ )
+ mem_indexes = torch.tensor([10, 50, 200], device="cuda")
+ topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda")
+ topk_ids_layer1 = topk_ids + 20
+
+ make_capture_callback = manager.make_capture_callback_factory(mem_indexes)
+ make_capture_callback(0)(topk_ids)
+ make_capture_callback(1)(topk_ids_layer1)
+
+ result = manager.extract_routing_data(mem_indexes)
+ np.testing.assert_array_equal(result[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))
+ np.testing.assert_array_equal(result[:, 1, :], topk_ids_layer1.cpu().numpy().astype(np.uint8))
+
+ def test_capture_maps_transformer_layer_num_to_routing_slot(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=2,
+ topk=2,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ layer_num_to_moe_index={3: 0, 7: 1},
+ )
+ mem_indexes = torch.tensor([10, 11], device="cuda")
+ ids_layer3 = torch.tensor([[1, 2], [3, 4]], device="cuda")
+ ids_layer7 = torch.tensor([[5, 6], [7, 8]], device="cuda")
+
+ make_capture_callback = manager.make_capture_callback_factory(mem_indexes)
+ make_capture_callback(3)(ids_layer3)
+ make_capture_callback(7)(ids_layer7)
+ assert make_capture_callback(4) is None
+
+ result = manager.extract_routing_data(mem_indexes)
+ np.testing.assert_array_equal(result[:, 0, :], ids_layer3.cpu().numpy().astype(np.uint8))
+ np.testing.assert_array_equal(result[:, 1, :], ids_layer7.cpu().numpy().astype(np.uint8))
+
+ def test_capture_rejects_unexpected_topk_width(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ )
+ mem_indexes = torch.tensor([10, 11], device="cuda")
+ topk_ids_with_shared_expert = torch.tensor([[1, 2, 32], [3, 4, 32]], device="cuda")
+
+ with pytest.raises(AssertionError):
+ manager.make_capture_callback_factory(mem_indexes)(0)(topk_ids_with_shared_expert)
+
+ def test_cuda_graph_replay_uses_copied_mem_indexes(self):
+ _skip_without_cuda()
+
+ class _NoopDecodeState:
+ def copy_for_decode_cuda_graph(self, other):
+ return
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ )
+ graph_infer_state = InferStateInfo()
+ graph_infer_state.decode_att_state = _NoopDecodeState()
+ graph_infer_state.mem_index = torch.tensor([10, 11], device="cuda")
+ graph_infer_state.make_routing_capture_callback = manager.make_capture_callback_factory(
+ graph_infer_state.mem_index
+ )
+ capture_callback = graph_infer_state.make_routing_capture_callback(0)
+ topk_ids = torch.tensor([[1, 2], [3, 4]], device="cuda")
+
+ graph = torch.cuda.CUDAGraph()
+ torch.cuda.synchronize()
+ with torch.cuda.graph(graph):
+ capture_callback(topk_ids)
+
+ graph.replay()
+ result = manager.extract_routing_data(graph_infer_state.mem_index)
+ np.testing.assert_array_equal(result[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))
+
+ new_infer_state = InferStateInfo()
+ new_infer_state.decode_att_state = _NoopDecodeState()
+ new_infer_state.mem_index = torch.tensor([20, 21], device="cuda")
+ new_topk_ids = torch.tensor([[5, 6], [7, 8]], device="cuda")
+ graph_infer_state.copy_for_cuda_graph(new_infer_state)
+ topk_ids.copy_(new_topk_ids)
+
+ graph.replay()
+
+ result = manager.extract_routing_data(new_infer_state.mem_index)
+ np.testing.assert_array_equal(result[:, 0, :], new_topk_ids.cpu().numpy().astype(np.uint8))
+
+ def test_microbatch_isolation(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=4,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ )
+ mem0 = torch.tensor([10, 11], device="cuda")
+ mem1 = torch.tensor([20, 21], device="cuda")
+ ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda")
+ ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2
+
+ capture0 = manager.make_capture_callback_factory(mem0)(0)
+ capture1 = manager.make_capture_callback_factory(mem1)(0)
+ capture0(ids_0)
+ capture1(ids_1)
+
+ result0 = manager.extract_routing_data(mem0)
+ result1 = manager.extract_routing_data(mem1)
+ assert result0[0, 0, 0] == 1
+ assert result1[0, 0, 0] == 2
+
+ def test_dtype_selection_uint8(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=256,
+ kv_cache_size=128,
+ max_capture_tokens=16,
+ )
+ assert manager.dtype == torch.uint8
+ assert manager.np_dtype == np.uint8
+ assert manager.dtype_id == 1
+
+ def test_dtype_selection_int16(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=257,
+ kv_cache_size=128,
+ max_capture_tokens=16,
+ )
+ assert manager.dtype == torch.int16
+ assert manager.np_dtype == np.int16
+ assert manager.dtype_id == 2
+
+ def test_extract_preserves_uint8_values(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=4,
+ num_experts=256,
+ kv_cache_size=64,
+ max_capture_tokens=16,
+ )
+ mem_indexes = torch.tensor([0, 1, 2], device="cuda")
+ topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 255, 3]], device="cuda")
+
+ manager.make_capture_callback_factory(mem_indexes)(0)(topk_ids)
+
+ result = manager.extract_routing_data(mem_indexes)
+ np.testing.assert_array_equal(result[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))
+
+ def test_routing_buffer_and_pointer_shape(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=48,
+ topk=8,
+ num_experts=256,
+ kv_cache_size=2048,
+ max_capture_tokens=256,
+ )
+ assert manager.routing_buffer.shape == (2048, 48, 8)
+ assert manager.routing_buffer.dtype == torch.uint8
+ assert manager.routing_buffer.device.type == "cpu"
+ assert manager.routing_buffer.is_pinned()
+ assert manager.routing_buffer_ptr.shape == (1,)
+ assert manager.routing_buffer_ptr.dtype == torch.uint64
+ assert manager.routing_buffer_ptr.device.type == "cuda"
+
+ def test_partial_token_capture(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=32,
+ kv_cache_size=128,
+ max_capture_tokens=16,
+ )
+ mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda")
+ topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda")
+
+ manager.make_capture_callback_factory(mem_indexes[:3])(0)(topk_ids)
+
+ result_written = manager.extract_routing_data(mem_indexes[:3])
+ np.testing.assert_array_equal(result_written[:, 0, :], topk_ids.cpu().numpy().astype(np.uint8))
+
+ result_unwritten = manager.extract_routing_data(mem_indexes[3:])
+ np.testing.assert_array_equal(result_unwritten[:, 0, :], np.zeros((2, 2), dtype=np.uint8))
+
+ def test_capture_does_not_allocate_capture_buffer(self):
+ _skip_without_cuda()
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=4,
+ topk=8,
+ num_experts=64,
+ kv_cache_size=1024,
+ max_capture_tokens=256,
+ )
+ assert not hasattr(manager, "_capture_buffer")
diff --git a/unit_tests/common/basemodel/triton_kernel/test_routing_capture.py b/unit_tests/common/basemodel/triton_kernel/test_routing_capture.py
new file mode 100644
index 0000000000..2bd9bdf06c
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_routing_capture.py
@@ -0,0 +1,133 @@
+import pytest
+import torch
+
+from lightllm.common.basemodel.triton_kernel.routing_capture import scatter_routing_topk_to_cpu
+
+
+def _skip_without_cuda():
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA is required for Triton kernels.")
+
+
+@pytest.mark.parametrize("dtype,dtype_id,max_value", [(torch.uint8, 1, 255), (torch.int16, 2, 1024)])
+def test_scatter_routing_topk_to_pinned_cpu(dtype, dtype_id, max_value):
+ _skip_without_cuda()
+
+ num_tokens = 5
+ num_moe_layers = 3
+ topk = 4
+ kv_cache_size = 32
+ moe_layer_index = 1
+
+ topk_ids = torch.randint(0, max_value, (num_tokens, topk), dtype=torch.int64, device="cuda")
+ mem_indexes = torch.tensor([17, 3, 29, 8, 11], dtype=torch.int32, device="cuda")
+ routing_buffer = torch.zeros(
+ (kv_cache_size, num_moe_layers, topk),
+ dtype=dtype,
+ device="cpu",
+ pin_memory=True,
+ )
+ routing_buffer_ptr = torch.tensor([routing_buffer.data_ptr()], dtype=torch.uint64, device="cuda")
+
+ scatter_routing_topk_to_cpu(
+ topk_ids=topk_ids,
+ mem_indexes=mem_indexes,
+ routing_buffer_ptr=routing_buffer_ptr,
+ moe_layer_index=moe_layer_index,
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ dtype_id=dtype_id,
+ )
+ torch.cuda.synchronize()
+
+ expected = torch.zeros_like(routing_buffer)
+ expected[mem_indexes.cpu().long(), moe_layer_index, :] = topk_ids.cpu().to(dtype)
+ assert torch.equal(routing_buffer, expected)
+
+
+def test_scatter_routing_topk_respects_layer_index():
+ _skip_without_cuda()
+
+ num_tokens = 3
+ num_moe_layers = 2
+ topk = 2
+ kv_cache_size = 16
+
+ topk_ids = torch.arange(num_tokens * topk, dtype=torch.int64, device="cuda").view(num_tokens, topk)
+ mem_indexes = torch.tensor([10, 4, 13], dtype=torch.int64, device="cuda")
+ routing_buffer = torch.zeros(
+ (kv_cache_size, num_moe_layers, topk),
+ dtype=torch.int16,
+ device="cpu",
+ pin_memory=True,
+ )
+ routing_buffer_ptr = torch.tensor([routing_buffer.data_ptr()], dtype=torch.uint64, device="cuda")
+
+ scatter_routing_topk_to_cpu(
+ topk_ids=topk_ids,
+ mem_indexes=mem_indexes,
+ routing_buffer_ptr=routing_buffer_ptr,
+ moe_layer_index=1,
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ dtype_id=2,
+ )
+ torch.cuda.synchronize()
+
+ expected = torch.zeros_like(routing_buffer)
+ expected[mem_indexes.cpu(), 1, :] = topk_ids.cpu().to(torch.int16)
+ assert torch.equal(routing_buffer, expected)
+
+
+def test_scatter_routing_topk_is_cuda_graph_capturable():
+ _skip_without_cuda()
+
+ num_tokens = 4
+ num_moe_layers = 2
+ topk = 3
+ kv_cache_size = 16
+
+ topk_ids = torch.arange(num_tokens * topk, dtype=torch.int64, device="cuda").view(num_tokens, topk)
+ mem_indexes = torch.tensor([2, 4, 6, 8], dtype=torch.int32, device="cuda")
+ routing_buffer = torch.zeros(
+ (kv_cache_size, num_moe_layers, topk),
+ dtype=torch.uint8,
+ device="cpu",
+ pin_memory=True,
+ )
+ routing_buffer_ptr = torch.tensor([routing_buffer.data_ptr()], dtype=torch.uint64, device="cuda")
+
+ scatter_routing_topk_to_cpu(
+ topk_ids=topk_ids,
+ mem_indexes=mem_indexes,
+ routing_buffer_ptr=routing_buffer_ptr,
+ moe_layer_index=0,
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ dtype_id=1,
+ )
+ torch.cuda.synchronize()
+ routing_buffer.zero_()
+
+ graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(graph):
+ scatter_routing_topk_to_cpu(
+ topk_ids=topk_ids,
+ mem_indexes=mem_indexes,
+ routing_buffer_ptr=routing_buffer_ptr,
+ moe_layer_index=0,
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ dtype_id=1,
+ )
+
+ graph.replay()
+ torch.cuda.synchronize()
+
+ expected = torch.zeros_like(routing_buffer)
+ expected[mem_indexes.cpu(), 0, :] = topk_ids.cpu().to(torch.uint8)
+ assert torch.equal(routing_buffer, expected)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
index 605433e9d8..dfeda0b6f7 100644
--- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
+++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
@@ -230,5 +230,32 @@ def test_case9():
assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))
+def test_case10():
+ """
+ 测试场景:测试 flush_cache 函数
+ """
+ print("\nTest Case 10: Testing flush_cache function\n")
+ tree = RadixCache("unique_name", 100, 0)
+ tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
+ tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
+ tree_node, size, values = tree.match_prefix(
+ torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+ )
+ assert tree_node is not None
+ assert size == 3
+ tree.flush_cache()
+ tree_node, size, values = tree.match_prefix(
+ torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+ )
+ assert tree_node is None
+ assert size == 0
+ assert tree.get_tree_total_tokens_num() == 0
+ assert tree.get_refed_tokens_num() == 0
+ assert len(tree.root_node.children) == 0
+ assert tree.root_node.token_id_key.numel() == 0
+ assert tree.root_node.token_mem_index_value.numel() == 0
+ assert tree.root_node.ref_counter == 1
+
+
if __name__ == "__main__":
pytest.main()