Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/CN/source/tutorial/api_server_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ PD 分离模式参数

激进调度可能导致解码期间频繁的预填充中断。禁用它可以让 router_max_wait_tokens 参数更有效地工作。

.. option:: --enable_prefill_decode_mixed

在同一次推理调度步骤中混合执行 prefill 与 decode。

仅支持 ``--run_mode`` 为 ``normal`` 时开启。当同时存在 prefill 与 decode 请求时,调度器会在同一步内
先执行 prefill、再执行 decode,而不是在激进调度下只执行 prefill、阻塞 decode,从而在有新 prefill
请求时也能推进 decode,提升整体吞吐。

不能与 ``--enable_prefill_microbatch_overlap`` 或 ``--enable_decode_microbatch_overlap`` 同时使用。

.. option:: --disable_dynamic_prompt_cache

禁用kv cache 缓存
Expand Down
11 changes: 11 additions & 0 deletions docs/EN/source/tutorial/api_server_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ Scheduling Parameters

Aggressive scheduling may cause frequent prefill interruptions during decoding. Disabling it can make the router_max_wait_tokens parameter work more effectively.

.. option:: --enable_prefill_decode_mixed

Enable mixed prefill and decode scheduling in the same inference step.

Only supported when ``--run_mode`` is ``normal``. When both prefill and decode requests are pending,
the scheduler runs prefill first and then decode in one scheduling step, instead of running only
prefill under aggressive scheduling. This improves decode throughput when new prefill requests arrive.

Cannot be used together with ``--enable_prefill_microbatch_overlap`` or
``--enable_decode_microbatch_overlap``.

.. option:: --disable_dynamic_prompt_cache

Disable kv cache caching
Expand Down
32 changes: 31 additions & 1 deletion lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from lightllm.common.basemodel.cuda_graph import CudaGraph
from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
Expand Down Expand Up @@ -498,6 +498,16 @@ def _prefill(
self,
model_input: ModelInput,
):
if self.args.enable_prefill_decode_mixed and model_input.b_is_decode_req is not None:
gather_token_prefill_decode_mixed(
input_ids=model_input.input_ids,
req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
b_req_idx=model_input.b_req_idx,
b_mtp_index=model_input.b_mtp_index,
b_is_decode_req=model_input.b_is_decode_req,
b_prefill_start_loc=model_input.b_prefill_start_loc,
)

origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
origin_batch_size = model_input.batch_size

Expand Down Expand Up @@ -697,6 +707,26 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
model_input0.to_cuda()
model_input1.to_cuda()

if self.args.enable_prefill_decode_mixed and model_input0.b_is_decode_req is not None:
gather_token_prefill_decode_mixed(
input_ids=model_input0.input_ids,
req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
b_req_idx=model_input0.b_req_idx,
b_mtp_index=model_input0.b_mtp_index,
b_is_decode_req=model_input0.b_is_decode_req,
b_prefill_start_loc=model_input0.b_prefill_start_loc,
)

if self.args.enable_prefill_decode_mixed and model_input1.b_is_decode_req is not None:
gather_token_prefill_decode_mixed(
input_ids=model_input1.input_ids,
req_to_next_token_ids=self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
b_req_idx=model_input1.b_req_idx,
b_mtp_index=model_input1.b_mtp_index,
b_is_decode_req=model_input1.b_is_decode_req,
b_prefill_start_loc=model_input1.b_prefill_start_loc,
)

assert model_input0.mem_indexes.is_cuda
assert model_input1.mem_indexes.is_cuda

Expand Down
11 changes: 11 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class ModelInput:
b_req_idx: torch.Tensor = None
b_mtp_index: torch.Tensor = None
b_seq_len: torch.Tensor = None
# 在 prefill 阶段,用于在 enable_prefill_decode_mixed 开启下,
# 用于标识请求是否为 decode 请求混合在 prefill 请求中。
# 其对应的 input_ids 需要特殊处理, 从 req_to_next_token_ids 中获取。

b_is_decode_req: torch.Tensor = None

# 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, 用于记录共享的radix cache中的长度
b_shared_seq_len: torch.Tensor = None
# 只会在 diverse_mode 下的 decode 阶段真正被使用的参数, 用于记录请求间的共享关系。
Expand Down Expand Up @@ -52,6 +58,11 @@ def to_cuda(self):
self.input_ids = self.input_ids.cuda(non_blocking=True)
if self.mem_indexes is None:
self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)

if self.b_is_decode_req is not None:
self.b_is_decode_req = self.b_is_decode_req.cuda(non_blocking=True)
assert self.is_prefill

self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
Expand Down
235 changes: 235 additions & 0 deletions lightllm/common/basemodel/triton_kernel/gather_token_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,75 @@ def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b
return output


@triton.jit
def _fwd_kernel_gather_prefill_decode_mixed(
input_ids,
req_to_next_token_ids,
req_to_next_token_ids_stride,
req_to_next_token_ids_stride_1,
b_req_idx,
b_mtp_index,
b_is_decode_req,
b_prefill_start_loc,
num_size,
BLOCK: tl.constexpr,
):
block_index = tl.program_id(0)
block_range = block_index * BLOCK + tl.arange(0, BLOCK)
block_mask = block_range < num_size
cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
cur_next_token_id = tl.load(
req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, mask=block_mask
)
cur_is_decode_req = tl.load(b_is_decode_req + block_range, mask=block_mask, other=False)
cur_prefill_start_loc = tl.load(b_prefill_start_loc + block_range, mask=block_mask, other=-1)

tl.store(input_ids + cur_prefill_start_loc, cur_next_token_id, mask=block_mask & cur_is_decode_req)
return


def gather_token_prefill_decode_mixed(
input_ids: torch.Tensor,
req_to_next_token_ids: torch.Tensor,
b_req_idx: torch.Tensor,
b_mtp_index: torch.Tensor,
b_is_decode_req: torch.Tensor,
b_prefill_start_loc: torch.Tensor,
):
"""
This function is used to gather the token_info(CPU tensor) to the token_info(GPU tensor).
Args:
input_ids: (batch_size,)
req_to_next_token_ids: (max_req_num, max_mtp_step)
b_req_idx: (batch_size,)
b_mtp_index: (batch_size,)
b_is_decode_req: (batch_size,)
b_prefill_start_loc: (batch_size,)
Returns:
input_ids:
"""
batch_size = b_req_idx.shape[0]
BLOCK = 256
grid = (triton.cdiv(batch_size, BLOCK),)
num_warps = 1
_fwd_kernel_gather_prefill_decode_mixed[grid](
input_ids=input_ids,
req_to_next_token_ids=req_to_next_token_ids,
req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
b_is_decode_req=b_is_decode_req,
b_prefill_start_loc=b_prefill_start_loc,
num_size=batch_size,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return input_ids


def test_scatter_token_to_cpu():
batch_size = 30
req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
Expand All @@ -166,6 +235,172 @@ def test_gather_token():
print("test_gather_token passed")


def _ref_gather_token_prefill_decode_mixed(
input_ids: torch.Tensor,
req_to_next_token_ids: torch.Tensor,
b_req_idx: torch.Tensor,
b_mtp_index: torch.Tensor,
b_is_decode_req: torch.Tensor,
b_prefill_start_loc: torch.Tensor,
) -> torch.Tensor:
out = input_ids.clone()
table = req_to_next_token_ids.detach().cpu()
req_idx_cpu = b_req_idx.detach().cpu()
mtp_cpu = b_mtp_index.detach().cpu()
is_decode_cpu = b_is_decode_req.detach().cpu()
start_loc_cpu = b_prefill_start_loc.detach().cpu()
for i in range(req_idx_cpu.shape[0]):
if is_decode_cpu[i].item():
rid = int(req_idx_cpu[i].item())
mid = int(mtp_cpu[i].item())
loc = int(start_loc_cpu[i].item())
out[loc] = table[rid, mid]
return out


def _run_gather_token_prefill_decode_mixed_case(
input_ids: torch.Tensor,
req_to_next_token_ids: torch.Tensor,
b_req_idx: torch.Tensor,
b_mtp_index: torch.Tensor,
b_is_decode_req: torch.Tensor,
b_prefill_start_loc: torch.Tensor,
):
input_cuda = input_ids.clone().cuda()
req_table = req_to_next_token_ids.cuda()
b_req_idx_cuda = b_req_idx.cuda()
b_mtp_index_cuda = b_mtp_index.cuda()
b_is_decode_cuda = b_is_decode_req.cuda()
b_start_loc_cuda = b_prefill_start_loc.cuda()

expected = _ref_gather_token_prefill_decode_mixed(
input_cuda,
req_table,
b_req_idx_cuda,
b_mtp_index_cuda,
b_is_decode_cuda,
b_start_loc_cuda,
)
gather_token_prefill_decode_mixed(
input_cuda,
req_table,
b_req_idx_cuda,
b_mtp_index_cuda,
b_is_decode_cuda,
b_start_loc_cuda,
)
diff = (input_cuda - expected).abs().max()
assert diff < 1e-6, f"max diff {diff.item()}"


def test_gather_token_prefill_decode_mixed_decode_only():
"""仅 decode 行:按 b_prefill_start_loc 写入 req_to_next_token_ids 中的 next token。"""
req_to_next_token_ids = torch.zeros((32, 4), dtype=torch.int64, device="cuda")
req_to_next_token_ids[3, 0] = 42
req_to_next_token_ids[7, 0] = 99
req_to_next_token_ids[11, 2] = 17

input_ids = torch.tensor([0, 0, 0], dtype=torch.int64, device="cuda")
b_req_idx = torch.tensor([3, 7, 11], dtype=torch.int32, device="cuda")
b_mtp_index = torch.tensor([0, 0, 2], dtype=torch.int32, device="cuda")
b_is_decode_req = torch.tensor([True, True, True], dtype=torch.bool, device="cuda")
b_prefill_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")

_run_gather_token_prefill_decode_mixed_case(
input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc
)
print("test_gather_token_prefill_decode_mixed_decode_only passed")


def test_gather_token_prefill_decode_mixed_mixed_batch():
"""prefill + decode 混合:仅 decode 位置被覆盖,prefill token 保持不变。"""
req_to_next_token_ids = torch.zeros((16, 2), dtype=torch.int64, device="cuda")
req_to_next_token_ids[5, 0] = 9001

# prefill [10,11,12] | decode placeholder | prefill [20,21]
input_ids = torch.tensor([10, 11, 12, -1, 20, 21], dtype=torch.int64, device="cuda")
b_req_idx = torch.tensor([0, 5, 1], dtype=torch.int32, device="cuda")
b_mtp_index = torch.tensor([0, 0, 0], dtype=torch.int32, device="cuda")
b_is_decode_req = torch.tensor([False, True, False], dtype=torch.bool, device="cuda")
b_q_seq_len = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len

_run_gather_token_prefill_decode_mixed_case(
input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc
)
print("test_gather_token_prefill_decode_mixed_mixed_batch passed")


def test_gather_token_prefill_decode_mixed_prefill_only_unchanged():
"""无 decode 行时 input_ids 不应被修改。"""
req_to_next_token_ids = torch.full((8, 1), 777, dtype=torch.int64, device="cuda")
input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
b_req_idx = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(3, dtype=torch.int32, device="cuda")
b_is_decode_req = torch.zeros(3, dtype=torch.bool, device="cuda")
b_q_seq_len = torch.tensor([2, 1, 1], dtype=torch.int32, device="cuda")
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len

before = input_ids.clone()
gather_token_prefill_decode_mixed(
input_ids,
req_to_next_token_ids,
b_req_idx,
b_mtp_index,
b_is_decode_req,
b_prefill_start_loc,
)
assert torch.equal(input_ids, before)
print("test_gather_token_prefill_decode_mixed_prefill_only_unchanged passed")


def test_gather_token_prefill_decode_mixed_large_batch():
"""batch_size > 256,覆盖多 block 的 triton grid。"""
batch_size = 300
max_req = 400
req_to_next_token_ids = torch.arange(max_req * 2, dtype=torch.int64, device="cuda").view(max_req, 2)
input_ids = torch.zeros(batch_size, dtype=torch.int64, device="cuda")
b_req_idx = torch.arange(10, 10 + batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = (b_req_idx % 2).to(torch.int32)
b_is_decode_req = torch.ones(batch_size, dtype=torch.bool, device="cuda")
b_prefill_start_loc = torch.arange(batch_size, dtype=torch.int32, device="cuda")

_run_gather_token_prefill_decode_mixed_case(
input_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_is_decode_req, b_prefill_start_loc
)
print("test_gather_token_prefill_decode_mixed_large_batch passed")


def test_gather_token_prefill_decode_mixed_roundtrip_with_scatter():
"""scatter_token 写入后,mixed gather 能读回同一 next token。"""
batch_size = 16
req_to_next_token_ids = torch.zeros((64, 3), dtype=torch.float32, pin_memory=True)
token_info = torch.arange(100, 100 + batch_size, dtype=torch.float32, device="cuda")
b_req_idx = torch.arange(4, 4 + batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
scatter_token(token_info, req_to_next_token_ids, b_req_idx, b_mtp_index)

input_ids = torch.zeros(batch_size, dtype=torch.int64, device="cuda")
b_is_decode_req = torch.ones(batch_size, dtype=torch.bool, device="cuda")
b_prefill_start_loc = torch.arange(batch_size, dtype=torch.int32, device="cuda")

gather_token_prefill_decode_mixed(
input_ids,
req_to_next_token_ids,
b_req_idx,
b_mtp_index,
b_is_decode_req,
b_prefill_start_loc,
)
assert torch.equal(input_ids, token_info.to(torch.int64))
print("test_gather_token_prefill_decode_mixed_roundtrip_with_scatter passed")


if __name__ == "__main__":
test_scatter_token_to_cpu()
test_gather_token()
test_gather_token_prefill_decode_mixed_decode_only()
test_gather_token_prefill_decode_mixed_mixed_batch()
test_gather_token_prefill_decode_mixed_prefill_only_unchanged()
test_gather_token_prefill_decode_mixed_large_batch()
test_gather_token_prefill_decode_mixed_roundtrip_with_scatter()
6 changes: 6 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="""aggressive schedule can lead to frequent prefill interruptions during decode.
disabling it allows the router_max_wait_tokens parameter to work more effectively.""",
)
parser.add_argument(
"--enable_prefill_decode_mixed",
action="store_true",
help="""when run_mode is normal, allow prefill and decode requests to run in the same
scheduling step when both exist, improving throughput under aggressive schedule.""",
)

parser.add_argument(
"--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use."
Expand Down
3 changes: 3 additions & 0 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def normal_or_p_d_start(args):
if args.enable_prefill_microbatch_overlap or args.enable_decode_microbatch_overlap:
args.enable_tpsp_mix_mode = True

if args.enable_prefill_decode_mixed:
assert args.run_mode == "normal", "--enable_prefill_decode_mixed only supports run_mode normal"
Comment on lines +191 to +192
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The documentation for --enable_prefill_decode_mixed explicitly states that it cannot be used together with --enable_prefill_microbatch_overlap or --enable_decode_microbatch_overlap. This constraint should be enforced here to prevent invalid configurations.

Suggested change
if args.enable_prefill_decode_mixed:
assert args.run_mode == "normal", "--enable_prefill_decode_mixed only supports run_mode normal"
if args.enable_prefill_decode_mixed:
assert args.run_mode == "normal", "--enable_prefill_decode_mixed only supports run_mode normal"
assert not args.enable_prefill_microbatch_overlap and not args.enable_decode_microbatch_overlap, \
"--enable_prefill_decode_mixed cannot be used with microbatch overlap"


if args.enable_dp_prefill_balance:
assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1"

Expand Down
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class StartArgs:
router_token_ratio: float = field(default=0.0)
router_max_wait_tokens: int = field(default=1)
disable_aggressive_schedule: bool = field(default=False)
enable_prefill_decode_mixed: bool = field(default=False)
disable_dynamic_prompt_cache: bool = field(default=False)
chunked_prefill_size: int = field(default=None)
disable_chunked_prefill: bool = field(default=False)
Expand Down
Loading
Loading