diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index bdfb4d721..c19cc9266 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -214,6 +214,16 @@ PD 分离模式参数 激进调度可能导致解码期间频繁的预填充中断。禁用它可以让 router_max_wait_tokens 参数更有效地工作。 +.. option:: --enable_prefill_decode_mixed + + 在同一次推理调度步骤中混合执行 prefill 与 decode。 + + 仅支持 ``--run_mode`` 为 ``normal`` 时开启。当同时存在 prefill 与 decode 请求时,调度器会在同一步内 + 先执行 prefill、再执行 decode,而不是在激进调度下只执行 prefill、阻塞 decode,从而在有新 prefill + 请求时也能推进 decode,提升整体吞吐。 + + 不能与 ``--enable_prefill_microbatch_overlap`` 或 ``--enable_decode_microbatch_overlap`` 同时使用。 + .. option:: --disable_dynamic_prompt_cache 禁用kv cache 缓存 diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index d7b798100..ad5b38130 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -213,6 +213,17 @@ Scheduling Parameters Aggressive scheduling may cause frequent prefill interruptions during decoding. Disabling it can make the router_max_wait_tokens parameter work more effectively. +.. option:: --enable_prefill_decode_mixed + + Enable mixed prefill and decode scheduling in the same inference step. + + Only supported when ``--run_mode`` is ``normal``. When both prefill and decode requests are pending, + the scheduler runs prefill first and then decode in one scheduling step, instead of running only + prefill under aggressive scheduling. This improves decode throughput when new prefill requests arrive. + + Cannot be used together with ``--enable_prefill_microbatch_overlap`` or + ``--enable_decode_microbatch_overlap``. + .. 
option:: --disable_dynamic_prompt_cache Disable kv cache caching diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 05aaaadca..c46617ae9 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -22,7 +22,7 @@ from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg -from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token +from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -498,6 +498,16 @@ def _prefill( self, model_input: ModelInput, ): + if self.args.enable_prefill_decode_mixed and model_input.b_is_decode_req is not None: + gather_token_prefill_decode_mixed( + input_ids=model_input.input_ids, + req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_idx=model_input.b_req_idx, + b_mtp_index=model_input.b_mtp_index, + b_is_decode_req=model_input.b_is_decode_req, + b_prefill_start_loc=model_input.b_prefill_start_loc, + ) + origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num origin_batch_size = model_input.batch_size @@ -697,6 +707,26 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod model_input0.to_cuda() model_input1.to_cuda() + if self.args.enable_prefill_decode_mixed and model_input0.b_is_decode_req is not None: + gather_token_prefill_decode_mixed( + input_ids=model_input0.input_ids, + req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_idx=model_input0.b_req_idx, + 
b_mtp_index=model_input0.b_mtp_index, + b_is_decode_req=model_input0.b_is_decode_req, + b_prefill_start_loc=model_input0.b_prefill_start_loc, + ) + + if self.args.enable_prefill_decode_mixed and model_input1.b_is_decode_req is not None: + gather_token_prefill_decode_mixed( + input_ids=model_input1.input_ids, + req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_idx=model_input1.b_req_idx, + b_mtp_index=model_input1.b_mtp_index, + b_is_decode_req=model_input1.b_is_decode_req, + b_prefill_start_loc=model_input1.b_prefill_start_loc, + ) + assert model_input0.mem_indexes.is_cuda assert model_input1.mem_indexes.is_cuda diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 758c0b519..1795ff9a8 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -20,6 +20,12 @@ class ModelInput: b_req_idx: torch.Tensor = None b_mtp_index: torch.Tensor = None b_seq_len: torch.Tensor = None + # 在 prefill 阶段,用于在 enable_prefill_decode_mixed 开启下, + # 用于标识请求是否为 decode 请求混合在 prefill 请求中。 + # 其对应的 input_ids 需要特殊处理, 从 req_to_next_token_ids 中获取。 + + b_is_decode_req: torch.Tensor = None + # 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, 用于记录共享的radix cache中的长度 b_shared_seq_len: torch.Tensor = None # 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, 用于记录请求间的共享关系。 @@ -52,6 +58,11 @@ def to_cuda(self): self.input_ids = self.input_ids.cuda(non_blocking=True) if self.mem_indexes is None: self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True) + + if self.b_is_decode_req is not None: + self.b_is_decode_req = self.b_is_decode_req.cuda(non_blocking=True) + assert self.is_prefill + self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) diff --git a/lightllm/common/basemodel/triton_kernel/gather_token_id.py 
b/lightllm/common/basemodel/triton_kernel/gather_token_id.py
index f8181d73c..16c7528b3 100644
--- a/lightllm/common/basemodel/triton_kernel/gather_token_id.py
+++ b/lightllm/common/basemodel/triton_kernel/gather_token_id.py
@@ -141,6 +141,91 @@ def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b
     return output
 
 
+@triton.jit
+def _fwd_kernel_gather_prefill_decode_mixed(
+    input_ids,
+    req_to_next_token_ids,
+    req_to_next_token_ids_stride,
+    req_to_next_token_ids_stride_1,
+    b_req_idx,
+    b_mtp_index,
+    b_is_decode_req,
+    b_prefill_start_loc,
+    num_size,
+    BLOCK: tl.constexpr,
+):
+    # Each lane handles one request row. Only rows flagged in b_is_decode_req
+    # overwrite input_ids[b_prefill_start_loc[row]] with that row's next token.
+    block_index = tl.program_id(0)
+    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
+    block_mask = block_range < num_size
+    cur_is_decode_req = tl.load(b_is_decode_req + block_range, mask=block_mask, other=False)
+    # Prefill rows neither read the next-token table nor write input_ids.
+    decode_mask = block_mask & cur_is_decode_req
+    cur_req_idx = tl.load(b_req_idx + block_range, mask=decode_mask, other=0)
+    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=decode_mask, other=0)
+    cur_prefill_start_loc = tl.load(b_prefill_start_loc + block_range, mask=decode_mask, other=0)
+    # Apply both strides so a non-contiguous table is still addressed correctly.
+    cur_next_token_id = tl.load(
+        req_to_next_token_ids
+        + cur_req_idx * req_to_next_token_ids_stride
+        + cur_mtp_index * req_to_next_token_ids_stride_1,
+        mask=decode_mask,
+    )
+
+    tl.store(input_ids + cur_prefill_start_loc, cur_next_token_id, mask=decode_mask)
+    return
+
+
+def gather_token_prefill_decode_mixed(
+    input_ids: torch.Tensor,
+    req_to_next_token_ids: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_mtp_index: torch.Tensor,
+    b_is_decode_req: torch.Tensor,
+    b_prefill_start_loc: torch.Tensor,
+):
+    """
+    Overwrite, in place, the input_ids slots of decode requests mixed into a prefill batch.
+
+    For every row i where b_is_decode_req[i] is True, the value
+    req_to_next_token_ids[b_req_idx[i], b_mtp_index[i]] is written to
+    input_ids[b_prefill_start_loc[i]]. Every other position is left untouched.
+
+    Args:
+        input_ids: flattened token ids of the batch, indexed by b_prefill_start_loc
+        req_to_next_token_ids: (max_req_num, max_mtp_step)
+        b_req_idx: (batch_size,)
+        b_mtp_index: (batch_size,)
+        b_is_decode_req: (batch_size,) bool mask
+        b_prefill_start_loc: (batch_size,)
+    Returns:
+        input_ids: the same tensor object that was passed in.
+    """
+    batch_size = b_req_idx.shape[0]
+    if batch_size == 0:
+        # Nothing to patch, and an empty launch grid is avoided.
+        return input_ids
+    BLOCK = 256
+    grid = (triton.cdiv(batch_size, BLOCK),)
+    num_warps = 1
+    _fwd_kernel_gather_prefill_decode_mixed[grid](
+        input_ids=input_ids,
+        req_to_next_token_ids=req_to_next_token_ids,
+        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
+        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
+        b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
+        b_is_decode_req=b_is_decode_req,
+        b_prefill_start_loc=b_prefill_start_loc,
+        num_size=batch_size,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return input_ids
+
+
 def test_scatter_token_to_cpu():
     batch_size = 30
     req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
@@ -166,6 +251,174 @@ def test_gather_token():
     print("test_gather_token passed")
 
 
+def _ref_gather_token_prefill_decode_mixed(
+    input_ids: torch.Tensor,
+    req_to_next_token_ids: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_mtp_index: torch.Tensor,
+    b_is_decode_req: torch.Tensor,
+    b_prefill_start_loc: torch.Tensor,
+) -> torch.Tensor:
+    """Pure-PyTorch reference: return a patched copy of input_ids, leaving the argument untouched."""
+    out = input_ids.clone()
+    table = req_to_next_token_ids.detach().cpu()
+    req_idx_cpu = b_req_idx.detach().cpu()
+    mtp_cpu = b_mtp_index.detach().cpu()
+    is_decode_cpu = b_is_decode_req.detach().cpu()
+    start_loc_cpu = b_prefill_start_loc.detach().cpu()
+    for i in range(req_idx_cpu.shape[0]):
+        if is_decode_cpu[i].item():
+            rid = int(req_idx_cpu[i].item())
+            mid = int(mtp_cpu[i].item())
+            loc = int(start_loc_cpu[i].item())
+            out[loc] = table[rid, mid]
+    return out
+
+
+def _run_gather_token_prefill_decode_mixed_case(
+    input_ids: torch.Tensor,
+    req_to_next_token_ids: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_mtp_index: torch.Tensor,
+    b_is_decode_req: torch.Tensor,
+    b_prefill_start_loc: torch.Tensor,
+):
+    """Run the triton gather on CUDA copies of the inputs and compare with the reference."""
+    input_cuda = input_ids.clone().cuda()
+    req_table = req_to_next_token_ids.cuda()
+    b_req_idx_cuda = b_req_idx.cuda()
+    b_mtp_index_cuda = b_mtp_index.cuda()
+    b_is_decode_cuda = b_is_decode_req.cuda()
+    b_start_loc_cuda = b_prefill_start_loc.cuda()
+
+    expected = _ref_gather_token_prefill_decode_mixed(
+        input_cuda,
+        req_table,
+        b_req_idx_cuda,
+        b_mtp_index_cuda,
+        b_is_decode_cuda,
+        b_start_loc_cuda,
+    )
+    gather_token_prefill_decode_mixed(
+        input_cuda,
+        req_table,
+        b_req_idx_cuda,
+        b_mtp_index_cuda,
+        b_is_decode_cuda,
+        b_start_loc_cuda,
+    )
+    # Token ids are integers, so the comparison must be exact rather than tolerance based.
+    assert torch.equal(input_cuda, expected), "mixed gather result differs from reference"
+
+
+def test_gather_token_prefill_decode_mixed_decode_only():
+    """Decode-only rows: the next token from req_to_next_token_ids is written at b_prefill_start_loc."""
+    req_to_next_token_ids = torch.zeros((32, 4), dtype=torch.int64, device="cuda")
+    req_to_next_token_ids[3, 0] = 42
+    req_to_next_token_ids[7, 0] = 99
+    req_to_next_token_ids[11, 2] = 17
+
+    input_ids = torch.tensor([0, 0, 0], dtype=torch.int64, device="cuda")
+    b_req_idx = torch.tensor([3, 7, 11], dtype=torch.int32, device="cuda")
+    b_mtp_index = torch.tensor([0, 0, 2], dtype=torch.int32, device="cuda")
+    b_is_decode_req = torch.tensor([True, True, True], dtype=torch.bool, device="cuda")
+    b_prefill_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
+
+    _run_gather_token_prefill_decode_mixed_case(
+        input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc
+    )
+    print("test_gather_token_prefill_decode_mixed_decode_only passed")
+
+
+def test_gather_token_prefill_decode_mixed_mixed_batch():
+    """Mixed prefill + decode: only decode slots are overwritten, prefill tokens stay unchanged."""
+    req_to_next_token_ids = torch.zeros((16, 2), dtype=torch.int64, device="cuda")
+    req_to_next_token_ids[5, 0] = 9001
+
+    # prefill [10,11,12] | decode placeholder | prefill [20,21]
+    input_ids = torch.tensor([10, 11, 12, -1, 20, 21], dtype=torch.int64, device="cuda")
+    b_req_idx = torch.tensor([0, 5, 1], dtype=torch.int32, device="cuda")
+    b_mtp_index = torch.tensor([0, 0, 0], dtype=torch.int32, device="cuda")
+    b_is_decode_req = torch.tensor([False, True, False], dtype=torch.bool, device="cuda")
+    b_q_seq_len = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
+    b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
+
+    _run_gather_token_prefill_decode_mixed_case(
+        input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc
+    )
+    print("test_gather_token_prefill_decode_mixed_mixed_batch passed")
+
+
+def test_gather_token_prefill_decode_mixed_prefill_only_unchanged():
+    """With no decode rows, input_ids must not be modified."""
+    req_to_next_token_ids = torch.full((8, 1), 777, dtype=torch.int64, device="cuda")
+    input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
+    b_req_idx = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
+    b_mtp_index = torch.zeros(3, dtype=torch.int32, device="cuda")
+    b_is_decode_req = torch.zeros(3, dtype=torch.bool, device="cuda")
+    b_q_seq_len = torch.tensor([2, 1, 1], dtype=torch.int32, device="cuda")
+    b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
+
+    before = input_ids.clone()
+    gather_token_prefill_decode_mixed(
+        input_ids,
+        req_to_next_token_ids,
+        b_req_idx,
+        b_mtp_index,
+        b_is_decode_req,
+        b_prefill_start_loc,
+    )
+    assert torch.equal(input_ids, before)
+    print("test_gather_token_prefill_decode_mixed_prefill_only_unchanged passed")
+
+
+def test_gather_token_prefill_decode_mixed_large_batch():
+    """batch_size > 256, so the triton grid spans more than one block."""
+    batch_size = 300
+    max_req = 400
+    req_to_next_token_ids = torch.arange(max_req * 2, dtype=torch.int64, device="cuda").view(max_req, 2)
+    input_ids = torch.zeros(batch_size, dtype=torch.int64, device="cuda")
+    b_req_idx = torch.arange(10, 10 + batch_size, dtype=torch.int32, device="cuda")
+    b_mtp_index = (b_req_idx % 2).to(torch.int32)
+    b_is_decode_req = torch.ones(batch_size, dtype=torch.bool, device="cuda")
+    b_prefill_start_loc = torch.arange(batch_size, dtype=torch.int32, device="cuda")
+
+    _run_gather_token_prefill_decode_mixed_case(
+        input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc
+    )
+    print("test_gather_token_prefill_decode_mixed_large_batch passed")
+
+
+def test_gather_token_prefill_decode_mixed_roundtrip_with_scatter():
+    """After scatter_token writes, the mixed gather reads back the same next tokens."""
+    batch_size = 16
+    req_to_next_token_ids = torch.zeros((64, 3), dtype=torch.float32, pin_memory=True)
+    token_info = torch.arange(100, 100 + batch_size, dtype=torch.float32, device="cuda")
+    b_req_idx = torch.arange(4, 4 + batch_size, dtype=torch.int32, device="cuda")
+    b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+    scatter_token(token_info, req_to_next_token_ids, b_req_idx, b_mtp_index)
+
+    input_ids = torch.zeros(batch_size, dtype=torch.int64, device="cuda")
+    b_is_decode_req = torch.ones(batch_size, dtype=torch.bool, device="cuda")
+    b_prefill_start_loc = torch.arange(batch_size, dtype=torch.int32, device="cuda")
+
+    gather_token_prefill_decode_mixed(
+        input_ids,
+        req_to_next_token_ids,
+        b_req_idx,
+        b_mtp_index,
+        b_is_decode_req,
+        b_prefill_start_loc,
+    )
+    assert torch.equal(input_ids, token_info.to(torch.int64))
+    print("test_gather_token_prefill_decode_mixed_roundtrip_with_scatter passed")
+
+
 if __name__ == "__main__":
     test_scatter_token_to_cpu()
     test_gather_token()
+    test_gather_token_prefill_decode_mixed_decode_only()
+    test_gather_token_prefill_decode_mixed_mixed_batch()
+    test_gather_token_prefill_decode_mixed_prefill_only_unchanged()
+    test_gather_token_prefill_decode_mixed_large_batch()
+    test_gather_token_prefill_decode_mixed_roundtrip_with_scatter()
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index f44a35af5..6b30ab687 100644
--- a/lightllm/server/api_cli.py
+++
b/lightllm/server/api_cli.py @@ -313,6 +313,12 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""aggressive schedule can lead to frequent prefill interruptions during decode. disabling it allows the router_max_wait_tokens parameter to work more effectively.""", ) + parser.add_argument( + "--enable_prefill_decode_mixed", + action="store_true", + help="""when run_mode is normal, allow prefill and decode requests to run in the same + scheduling step when both exist, improving throughput under aggressive schedule.""", + ) parser.add_argument( "--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use." diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index e1182f2f7..249839b0a 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -188,6 +188,9 @@ def normal_or_p_d_start(args): if args.enable_prefill_microbatch_overlap or args.enable_decode_microbatch_overlap: args.enable_tpsp_mix_mode = True + if args.enable_prefill_decode_mixed: + assert args.run_mode == "normal", "--enable_prefill_decode_mixed only supports run_mode normal" + if args.enable_dp_prefill_balance: assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1" diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index fe9cb6161..05ff2658e 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -77,6 +77,7 @@ class StartArgs: router_token_ratio: float = field(default=0.0) router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) + enable_prefill_decode_mixed: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) chunked_prefill_size: int = field(default=None) disable_chunked_prefill: bool = field(default=False) diff --git 
a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ca982ec0f..ef93ad5fd 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -679,6 +679,19 @@ def _get_classed_reqs( paused_reqs=paused_reqs, is_master_in_dp=self.is_master_in_dp, can_alloc_token_num=can_alloc_token_num ) + # 在 enable_prefill_decode_mixed 模式下,如果存在 prefill 请求和 decode 请求, + # 并且 prefill 请求需要的 token 数量 + decode 请求需要的 token 数量小于等于 batch_max_tokens, + # 则将 decode 请求合并到 prefill 请求中。 + if self.args.enable_prefill_decode_mixed and len(prefill_reqs) > 0 and len(decode_reqs) > 0: + if prefill_tokens + len(decode_reqs) <= self.batch_max_tokens: + for decode_req in decode_reqs: + # 给 decode req 添加一个属性标签,标识其为混合prefill的请求。 + # 在 prefill 阶段,会根据这个属性标签, 对这些请求的处理进行一些 + # 特殊化,主要是构建获取input_ids 的方式。 + decode_req.is_decode_req_mixed_in_prefill = True + prefill_reqs.append(decode_req) + decode_reqs = [] + return prefill_reqs, decode_reqs def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 03ac4cfb0..d3796b639 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -36,6 +36,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] + b_is_decode_req = [] for req in req_objs: @@ -57,6 +58,14 @@ def padded_prepare_prefill_inputs( b_ready_cache_len.append(req.cur_kv_len) b_mtp_index.append(0) + # enable_prefill_decode_mixed 模式下,decode 请求混合在 prefill 请求中。 + # 需要的特殊标记。 + if hasattr(req, "is_decode_req_mixed_in_prefill"): + b_is_decode_req.append(True) + del req.is_decode_req_mixed_in_prefill +
else: + b_is_decode_req.append(False) + # padding fake req for prefill for _ in range(padded_req_num): input_ids.append([1]) @@ -69,6 +78,7 @@ def padded_prepare_prefill_inputs( total_token_num += 1 prefix_total_token_num += 0 batch_multimodal_params.append({"images": [], "audios": []}) + b_is_decode_req.append(False) max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) @@ -78,6 +88,7 @@ def padded_prepare_prefill_inputs( input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu") b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") + b_is_decode_req = torch.tensor(b_is_decode_req, dtype=torch.bool, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu") @@ -110,6 +121,7 @@ def padded_prepare_prefill_inputs( b_req_idx=b_req_idx, b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, + b_is_decode_req=b_is_decode_req, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, is_prefill=True, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 4eb8c7e1e..ae1af1956 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -22,6 +22,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_ready_cache_len = [] b_mtp_index = [] b_prefill_has_output = [] + b_is_decode_req = [] for req in req_objs: run_reqs.append(req) @@ -47,6 +48,11 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> prefix_total_token_num += req.cur_kv_len b_ready_cache_len.append(req.cur_kv_len) b_mtp_index.append(0) + if 
hasattr(req, "is_decode_req_mixed_in_prefill"): + b_is_decode_req.append(True) + del req.is_decode_req_mixed_in_prefill + else: + b_is_decode_req.append(False) max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) @@ -56,6 +62,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu") b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") + b_is_decode_req = torch.tensor(b_is_decode_req, dtype=torch.bool, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu") @@ -79,6 +86,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_req_idx=b_req_idx, b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, + b_is_decode_req=b_is_decode_req, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, is_prefill=True, diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py index dbce73a94..3ef039543 100644 --- a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py @@ -7,7 +7,12 @@ def prepare_mtp_prefill_inputs( model_input: ModelInput, b_next_token_ids: torch.Tensor, mtp_draft_input_hiddens: torch.Tensor ): + # enable_prefill_decode_mixed 模式下,decode 请求混合在 prefill 请求中。 + # 但是mtp的input_ids已经是恢复ok,已经是正常的input_ids, 所以移除掉 b_is_decode_req。 + # 防止在 forward 阶段,因为 b_is_decode_req 不为空,导致 input_ids 被特殊处理。 new_model_input = copy.copy(model_input) + new_model_input.b_is_decode_req = None + new_input_ids = gen_mtp_new_input_ids( input_ids=model_input.input_ids, 
b_next_token_ids=b_next_token_ids,