From 88cc50c393c4832f028d052b61d0df4fab2df1e1 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 18 Jun 2026 12:44:35 +0000 Subject: [PATCH 01/14] fix stream fc for qwen3_coder --- lightllm/server/function_call_parser.py | 174 +++++++++++++++++++++++- 1 file changed, 169 insertions(+), 5 deletions(-) diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index dfcb2f8d9..6d42ca87a 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1758,6 +1758,10 @@ def __init__(self): self.parameter_regex = re.compile( r"|(?=)|$)", re.DOTALL ) + self.parameter_stream_regex = re.compile( + r"]*)>(.*?)(|(?=)|$)", re.DOTALL + ) + self.function_start_regex = re.compile(r"]*)>", re.DOTALL) self._normal_text_buffer = "" def has_tool_call(self, text: str) -> bool: @@ -1890,6 +1894,111 @@ def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools return json.dumps(param_dict, ensure_ascii=False) + def _get_qwen3_param_type(self, param_name: str, param_config: Dict) -> str: + prop = param_config.get(param_name, {}) + return str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + + def _json_string_prefix(self, value: str, closed: bool) -> str: + dumped = json.dumps(value, ensure_ascii=False) + return dumped if closed else dumped[:-1] + + def _strip_partial_xml_suffix(self, value: str) -> str: + for token in ("", "", self.eot_token): + max_len = min(len(value), len(token) - 1) + for suffix_len in range(max_len, 0, -1): + if token.startswith(value[-suffix_len:]): + return value[:-suffix_len] + return value + + def _build_streaming_arguments_json( + self, + func_name: str, + partial_body: str, + tools: List[Tool], + close_object: bool = False, + ) -> Optional[str]: + """Build a monotonic JSON arguments prefix for XML tool-call streaming.""" + param_config = self._get_param_config(func_name, tools) + parts = ["{"] + has_param = False + + for match in self.parameter_stream_regex.finditer(partial_body): + param_name = match.group(1).strip() + if not param_name: + continue + + param_value = match.group(2) + closed = match.group(3) == "" + if not closed: + param_value = self._strip_partial_xml_suffix(param_value) + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + if closed: + param_value = param_value[:-1] + else: + # Keep the template delimiter newline buffered until we know + # whether it is part of the value or just precedes . + param_value = param_value[:-1] + + if has_param: + parts.append(", ") + parts.append(json.dumps(param_name, ensure_ascii=False)) + parts.append(": ") + + param_type = self._get_qwen3_param_type(param_name, param_config) + if param_type in ("string", "str", "enum") or param_name not in param_config: + parts.append(self._json_string_prefix(param_value, closed)) + elif closed: + converted = self._convert_param_value(param_value, param_name, param_config, func_name) + converted_json = json.dumps(converted, ensure_ascii=False) + if converted_json.startswith(param_value): + parts.append(converted_json) + else: + parts.append(param_value) + else: + parts.append(param_value) + + has_param = True + if not closed: + break + + if not has_param and close_object: + return "{}" + if not has_param: + return None + + if close_object: + parts.append("}") + + return "".join(parts) + + def _ensure_qwen3_stream_state(self, tool_index: int) -> None: + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + def _append_qwen3_arguments_delta(self, calls: List[ToolCallItem], tool_index: int, current_args_json: str) -> None: + sent_args = self.streamed_args_for_tool[tool_index] + if not current_args_json.startswith(sent_args): + logger.warning( + "Qwen3-Coder streaming arguments are not monotonic for tool index %s; skip delta.", + tool_index, + ) + return + + argument_diff = current_args_json[len(sent_args) :] + if argument_diff: + calls.append( + ToolCallItem( + tool_index=tool_index, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[tool_index] += argument_diff + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: idx = text.find(self.bot_token) normal_text = text[:idx].strip() if idx != -1 else text @@ -1941,23 +2050,78 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami self._buffer = current_text[tool_call_start:] current_text = self._buffer + if self.current_tool_id == -1: + self.current_tool_id = 0 + + self._ensure_qwen3_stream_state(self.current_tool_id) + + function_match = self.function_start_regex.search(current_text) + if not function_match: + return StreamingParseResult(normal_text=normal_text, calls=calls) + + func_name = function_match.group(1).strip() eot_pos = current_text.find(self.eot_token) + if func_name not in self._tool_indices: + if eot_pos == -1: + return StreamingParseResult(normal_text=normal_text, calls=calls) + logger.warning(f"Model attempted to call undefined function: {func_name}") + self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() + self.current_tool_name_sent = False + continue + + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + + function_close_pos = current_text.find("", function_match.end()) + close_object = function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos) + partial_end = function_close_pos if close_object else eot_pos + if partial_end == -1: + partial_end = len(current_text) + partial_body = current_text[function_match.end() : partial_end] + current_args_json = self._build_streaming_arguments_json( + func_name, + partial_body, + tools, + close_object=close_object, + ) + if current_args_json: + self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json) + if eot_pos == -1: return StreamingParseResult(normal_text=normal_text, calls=calls) complete_block = current_text[: eot_pos + len(self.eot_token)] func_matches = self.function_regex.findall(complete_block) - if self.current_tool_id == -1: - self.current_tool_id = 0 - for match in func_matches: func_str = match[0] if match[0] else match[1] item = self._parse_function_call(func_str, tools) if item: - item.tool_index = self.current_tool_id - calls.append(item) + completed_tool_id = self.current_tool_id + try: + parsed_args = json.loads(item.parameters) + except json.JSONDecodeError: + parsed_args = {} + self.prev_tool_call_arr[completed_tool_id] = { + "name": item.name, + "arguments": parsed_args, + } + sent_args = self.streamed_args_for_tool[completed_tool_id] + if item.parameters.startswith(sent_args): + self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters) self.current_tool_id += 1 + self.current_tool_name_sent = False self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() From a62103d588e6f364b48da99030d21e643485d275 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 18 Jun 2026 21:58:30 +0800 Subject: [PATCH 02/14] fix(qwen3_coder): correct streaming tool-call argument parsing Keep streamed_args_for_tool a byte-exact prefix of json.dumps(arguments) so the serving-layer stop-time reconciliation can't produce duplicated/invalid JSON. String values stream incrementally; non-string values are emitted only once arrives (a partial number/array/bool isn't a prefix of its json.dumps form). Detect "value still streaming" via the terminator rather than match position ($ matches before the template's trailing newline). Also: - don't crash (IndexError) on >=2 in one block, and emit a name head for each; flush via _ensure_qwen3_stream_state. - an undefined first function no longer discards the whole block (a valid function after it is still emitted). - treat as an implicit function close so the closing '}' rides in the same args delta (no separate trailing-'}' delta to double-count). - drop dead _build_partial_arguments_json; dedup newline-strip / param-type helpers. Add unit_tests/server/test_qwen3_coder_stream_fc.py covering these across chunk boundaries. --- lightllm/server/function_call_parser.py | 248 ++++++++---------- .../server/test_qwen3_coder_stream_fc.py | 194 ++++++++++++++ 2 files changed, 310 insertions(+), 132 deletions(-) create mode 100644 unit_tests/server/test_qwen3_coder_stream_fc.py diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 6d42ca87a..093c0c568 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1786,8 +1786,7 @@ def _convert_param_value(self, value: str, param_name: str, param_config: Dict, if param_name not in param_config: return value - prop = param_config.get(param_name, {}) - param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + param_type = self._get_qwen3_param_type(param_name, param_config) if param_type in ("string", "str", "enum"): return value @@ -1837,12 +1836,7 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional except ValueError: continue param_name = match[:idx].strip() - param_value = match[idx + 1 :] - # Strip leading/trailing newlines from value - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] + param_value = self._strip_value_newlines(match[idx + 1 :]) param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) @@ -1852,55 +1846,17 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional parameters=json.dumps(param_dict, ensure_ascii=False), ) - def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools: List[Tool]) -> Optional[str]: - """Build the current argument JSON from a partial XML tool-call body.""" - param_matches = self.parameter_regex.findall(partial_body) - if not param_matches: - return None - - param_config = self._get_param_config(func_name, tools) - param_dict = {} - has_visible_value = False - - for match in param_matches: - try: - idx = match.index(">") - except ValueError: - continue - - param_name = match[:idx].strip() - param_value = match[idx + 1 :] - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - if param_value.strip(): - has_visible_value = True - elif ( - f"" in partial_body - and f"{param_value}" in partial_body - ): - # Closed empty-string parameter. We can safely emit it. - has_visible_value = True - else: - # Parameter tag is present but its value has not started streaming yet. - continue - - param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) - - if not param_dict and not has_visible_value: - return None - - return json.dumps(param_dict, ensure_ascii=False) - def _get_qwen3_param_type(self, param_name: str, param_config: Dict) -> str: prop = param_config.get(param_name, {}) return str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" - def _json_string_prefix(self, value: str, closed: bool) -> str: - dumped = json.dumps(value, ensure_ascii=False) - return dumped if closed else dumped[:-1] + def _strip_value_newlines(self, value: str) -> str: + """Strip the single leading/trailing newline the Qwen3 template wraps each value in.""" + if value.startswith("\n"): + value = value[1:] + if value.endswith("\n"): + value = value[:-1] + return value def _strip_partial_xml_suffix(self, value: str) -> str: for token in ("", "", self.eot_token): @@ -1917,7 +1873,14 @@ def _build_streaming_arguments_json( tools: List[Tool], close_object: bool = False, ) -> Optional[str]: - """Build a monotonic JSON arguments prefix for XML tool-call streaming.""" + """Build a monotonic JSON arguments prefix for XML tool-call streaming. + + The result is always a byte-exact prefix of json.dumps(final_arguments) so the + serving layer (api_openai.py) can reconcile the streamed args at stream end. + String values stream character-by-character (a string prefix stays a prefix); + non-string values are only emitted once their arrives, because a + partial number/array/bool is not guaranteed to be a prefix of its json.dumps form. + """ param_config = self._get_param_config(func_name, tools) parts = ["{"] has_param = False @@ -1927,46 +1890,45 @@ def _build_streaming_arguments_json( if not param_name: continue - param_value = match.group(2) - closed = match.group(3) == "" - if not closed: - param_value = self._strip_partial_xml_suffix(param_value) - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - if closed: - param_value = param_value[:-1] - else: - # Keep the template delimiter newline buffered until we know - # whether it is part of the value or just precedes . - param_value = param_value[:-1] + # The value is complete only when an explicit closed it, or a + # sibling follows. Otherwise it is still streaming. + # (We can't key off match.end()==len: `$` matches before a trailing newline, + # and the template wraps every value in one, which would look "complete".) + rest = partial_body[match.end() :] + value_open = ( + match.group(3) != "" + and not rest.startswith("") + ) if has_param: parts.append(", ") parts.append(json.dumps(param_name, ensure_ascii=False)) parts.append(": ") + has_param = True param_type = self._get_qwen3_param_type(param_name, param_config) - if param_type in ("string", "str", "enum") or param_name not in param_config: - parts.append(self._json_string_prefix(param_value, closed)) - elif closed: - converted = self._convert_param_value(param_value, param_name, param_config, func_name) - converted_json = json.dumps(converted, ensure_ascii=False) - if converted_json.startswith(param_value): - parts.append(converted_json) - else: - parts.append(param_value) + is_string = param_type in ("string", "str", "enum") + + if value_open: + # In-progress (and therefore last) parameter. + if is_string: + value = self._strip_value_newlines(self._strip_partial_xml_suffix(match.group(2))) + # Drop the closing quote so the stream stays an extendable prefix. + parts.append(json.dumps(value, ensure_ascii=False)[:-1]) + # Non-string values cannot be emitted as a safe partial prefix, so stop + # after the key and wait for the value to close. + return "".join(parts) + + value = self._strip_value_newlines(match.group(2)) + if is_string: + parts.append(json.dumps(value, ensure_ascii=False)) else: - parts.append(param_value) - - has_param = True - if not closed: - break + converted = self._convert_param_value(value, param_name, param_config, func_name) + parts.append(json.dumps(converted, ensure_ascii=False)) - if not has_param and close_object: - return "{}" if not has_param: - return None + return "{}" if close_object else None if close_object: parts.append("}") @@ -2061,42 +2023,51 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami func_name = function_match.group(1).strip() eot_pos = current_text.find(self.eot_token) - if func_name not in self._tool_indices: - if eot_pos == -1: - return StreamingParseResult(normal_text=normal_text, calls=calls) - logger.warning(f"Model attempted to call undefined function: {func_name}") - self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() - self.current_tool_name_sent = False - continue + func_defined = func_name in self._tool_indices - if not self.current_tool_name_sent: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=func_name, - parameters="", + # Undefined function whose block has not finished yet: wait for more text + # (the block may also contain a valid function we shouldn't drop). + if not func_defined and eot_pos == -1: + return StreamingParseResult(normal_text=normal_text, calls=calls) + + if func_defined: + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) ) - ) - self.current_tool_name_sent = True - self.prev_tool_call_arr[self.current_tool_id] = { - "name": func_name, - "arguments": {}, - } + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } - function_close_pos = current_text.find("", function_match.end()) - close_object = function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos) - partial_end = function_close_pos if close_object else eot_pos - if partial_end == -1: - partial_end = len(current_text) - partial_body = current_text[function_match.end() : partial_end] - current_args_json = self._build_streaming_arguments_json( - func_name, - partial_body, - tools, - close_object=close_object, - ) - if current_args_json: - self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json) + # The function body is complete once we hit either or the + # enclosing ; treating eot as an implicit close lets us emit + # the closing '}' inside the same args delta (so the serving layer's + # stop-time reconciliation sees one delta, not a separate trailing '}'). + function_close_pos = current_text.find("", function_match.end()) + if function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos): + partial_end = function_close_pos + close_object = True + elif eot_pos != -1: + partial_end = eot_pos + close_object = True + else: + partial_end = len(current_text) + close_object = False + partial_body = current_text[function_match.end() : partial_end] + current_args_json = self._build_streaming_arguments_json( + func_name, + partial_body, + tools, + close_object=close_object, + ) + if current_args_json: + self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json) if eot_pos == -1: return StreamingParseResult(normal_text=normal_text, calls=calls) @@ -2104,24 +2075,37 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami complete_block = current_text[: eot_pos + len(self.eot_token)] func_matches = self.function_regex.findall(complete_block) + # Flush every completed function in the block. _parse_function_call returns + # None for undefined ones, so they are skipped without advancing the index. for match in func_matches: func_str = match[0] if match[0] else match[1] item = self._parse_function_call(func_str, tools) - if item: - completed_tool_id = self.current_tool_id - try: - parsed_args = json.loads(item.parameters) - except json.JSONDecodeError: - parsed_args = {} - self.prev_tool_call_arr[completed_tool_id] = { - "name": item.name, - "arguments": parsed_args, - } - sent_args = self.streamed_args_for_tool[completed_tool_id] - if item.parameters.startswith(sent_args): - self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters) - self.current_tool_id += 1 - self.current_tool_name_sent = False + if not item: + continue + completed_tool_id = self.current_tool_id + self._ensure_qwen3_stream_state(completed_tool_id) + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=completed_tool_id, + name=item.name, + parameters="", + ) + ) + self.current_tool_name_sent = True + try: + parsed_args = json.loads(item.parameters) + except json.JSONDecodeError: + parsed_args = {} + self.prev_tool_call_arr[completed_tool_id] = { + "name": item.name, + "arguments": parsed_args, + } + sent_args = self.streamed_args_for_tool[completed_tool_id] + if item.parameters.startswith(sent_args): + self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters) + self.current_tool_id += 1 + self.current_tool_name_sent = False self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() diff --git a/unit_tests/server/test_qwen3_coder_stream_fc.py b/unit_tests/server/test_qwen3_coder_stream_fc.py new file mode 100644 index 000000000..f72a642b1 --- /dev/null +++ b/unit_tests/server/test_qwen3_coder_stream_fc.py @@ -0,0 +1,194 @@ +"""Unit tests for Qwen3-Coder XML streaming tool-call parsing. + +These drive ``Qwen3CoderDetector.parse_streaming_increment`` directly (no server), +reassembling the per-tool argument string exactly the way ``api_openai.py`` does for +streamed responses, including the ``finish_reason == "stop"`` reconciliation. The key +invariant under test: the reassembled arguments are always valid JSON equal to what the +one-shot ``detect_and_parse`` would produce, for every chunk boundary. +""" + +import json +import pytest + +from lightllm.server.api_models import Function, Tool +from lightllm.server.function_call_parser import Qwen3CoderDetector + +CHUNK_SIZES = [1, 2, 3, 5, 13, 10_000] + + +def _tool(name, properties): + return Tool( + type="function", + function=Function(name=name, description="", parameters={"type": "object", "properties": properties}), + ) + + +def _stream_and_reassemble(text, tools, chunk): + """Feed ``text`` to the detector in fixed-size chunks and rebuild the client-visible + tool calls the way api_openai.py stream_results does (stop-rewrite on the last chunk).""" + det = Qwen3CoderDetector() + chunks = [text[i : i + chunk] for i in range(0, len(text), chunk)] + per_tool = {} + for ci, piece in enumerate(chunks): + result = det.parse_streaming_increment(piece, tools) + is_last = ci == len(chunks) - 1 + for call in result.calls: + ti = call.tool_index + per_tool.setdefault(ti, {"name": None, "args": ""}) + if call.name is not None: + per_tool[ti]["name"] = call.name + params = call.parameters + if is_last and params: + # Mirror api_openai.py:559-575 (REPLACE semantics). + latest_delta_len = len(params) + expected = json.dumps(det.prev_tool_call_arr[ti].get("arguments", {}), ensure_ascii=False) + actual = det.streamed_args_for_tool[ti] + if latest_delta_len > 0: + actual = actual[:-latest_delta_len] + params = expected.replace(actual, "", 1) + if params: + per_tool[ti]["args"] += params + return det, per_tool + + +def _assert_tool_calls(text, tools, expected, chunk_sizes=CHUNK_SIZES): + """expected: {tool_index: (name, args_dict)}.""" + for chunk in chunk_sizes: + det, per_tool = _stream_and_reassemble(text, tools, chunk) + assert len(per_tool) == len(expected), f"chunk={chunk}: tool count {len(per_tool)} != {len(expected)}" + for ti, (name, args) in expected.items(): + got = per_tool[ti] + assert got["name"] == name, f"chunk={chunk}: tool {ti} name {got['name']!r} != {name!r}" + parsed = json.loads(got["args"]) # must be valid JSON + assert parsed == args, f"chunk={chunk}: tool {ti} args {parsed!r} != {args!r}" + + +@pytest.mark.parametrize("chunk", CHUNK_SIZES) +def test_single_string_param(chunk): + text = ( + "\n\n\n" + "San Francisco\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("get_weather", {"location": {"type": "string"}})], + {0: ("get_weather", {"location": "San Francisco"})}, + [chunk], + ) + + +def test_array_param_compact_spacing(): + # Regression: a non-string value whose raw text ("[1,2]") differs from json.dumps + # ("[1, 2]") used to break the streamed-args prefix invariant -> duplicated/invalid JSON. + text = "\n\n\n[1,2]\n\n\n" + _assert_tool_calls(text, [_tool("calc", {"nums": {"type": "array"}})], {0: ("calc", {"nums": [1, 2]})}) + + +def test_number_param_reformatted(): + # "1.0" is parsed to int 1 and json.dumps'd as "1"; the stream must agree. + text = "\n\n\n1.0\n\n\n" + _assert_tool_calls(text, [_tool("calc", {"v": {"type": "number"}})], {0: ("calc", {"v": 1})}) + + +def test_boolean_param(): + text = "\n\n\ntrue\n\n\n" + _assert_tool_calls(text, [_tool("set", {"flag": {"type": "boolean"}})], {0: ("set", {"flag": True})}) + + +def test_object_param(): + text = '\n\n\n{"a":1,"b":[2,3]}\n\n\n' + _assert_tool_calls(text, [_tool("f", {"cfg": {"type": "object"}})], {0: ("f", {"cfg": {"a": 1, "b": [2, 3]}})}) + + +def test_two_params_mixed_types(): + text = ( + "\n\n\nNYC\n\n" + "\n3\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("f", {"city": {"type": "string"}, "days": {"type": "integer"}})], + {0: ("f", {"city": "NYC", "days": 3})}, + ) + + +def test_multiline_string_value(): + text = ( + "\n\n\nline1\nline2\nline3\n" "\n\n" + ) + _assert_tool_calls(text, [_tool("f", {"code": {"type": "string"}})], {0: ("f", {"code": "line1\nline2\nline3"})}) + + +def test_string_with_json_special_chars(): + text = '\n\n\nsay "hi"\\path\n\n\n' + _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": 'say "hi"\\path'})}) + + +def test_empty_string_value(): + text = "\n\n\n\n\n\n" + _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": ""})}) + + +def test_no_param_function(): + text = "\n\n\n" + _assert_tool_calls(text, [_tool("ping", {})], {0: ("ping", {})}) + + +def test_two_separate_tool_call_blocks(): + text = ( + "\n\n\nhi\n\n\n\n" + "\n\n\nyo\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})], + {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})}, + ) + + +def test_two_functions_in_one_block(): + # Regression: used to raise IndexError on the second function in a single block. + text = ( + "\n\n\nhi\n\n\n" + "\n\nyo\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})], + {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})}, + ) + + +def test_undefined_then_valid_in_same_block(): + # Regression: an undefined first function used to discard the whole block, dropping + # the valid call that followed it. + text = ( + "\n\n\nhi\n\n\n" + "\n\nyo\n\n\n" + ) + _assert_tool_calls(text, [_tool("valid", {"y": {"type": "string"}})], {0: ("valid", {"y": "yo"})}) + + +def test_truncated_call_missing_function_close(): + # Regression: a typed value with no before used to leave the + # streamed args unterminated (missing closing brace). + text = "\n\n\n0.50\n\n" + _assert_tool_calls(text, [_tool("calc", {"x": {"type": "number"}})], {0: ("calc", {"x": 0.5})}) + + +def test_streaming_matches_non_stream(): + # The reassembled streamed args must equal the one-shot detect_and_parse output. + tools = [_tool("f", {"city": {"type": "string"}, "n": {"type": "integer"}, "tags": {"type": "array"}})] + text = ( + "\n\n\nLondon\n\n" + '\n7\n\n\n["a","b"]\n\n\n' + ) + oneshot = Qwen3CoderDetector().detect_and_parse(text, tools) + expected_args = json.loads(oneshot.calls[0].parameters) + _assert_tool_calls(text, tools, {0: ("f", expected_args)}) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"])) From 7093bdadf88cd4227f7492f148e6d793608473b0 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 11:26:32 +0800 Subject: [PATCH 03/14] refactor(logging): colored levels, windowed cache stats, quieter per-request logs - Add ANSI color codes to log level names (TTY only; plain in files) - Introduce SystemStatusReporter with windowed prefix-cache hit rate alongside the global rate, plus a more compact status line - Drop gunicorn --access-logfile flags (FastAPI middleware now handles it) - Remove duplicate _ACCESS_LOG_STATUS_COLORS declaration in api_http.py - Downgrade noisy per-request / per-batch progress logs from INFO to DEBUG - Fix flake8 F841 (unused exception variable) in detokenization manager --- lightllm/common/basemodel/cuda_graph.py | 15 ++- lightllm/common/triton_utils/autotuner.py | 6 +- lightllm/server/api_http.py | 1 - lightllm/server/api_start.py | 6 - lightllm/server/detokenization/manager.py | 9 +- lightllm/server/httpserver/manager.py | 8 +- .../httpserver_for_pd_master/manager.py | 2 +- lightllm/server/router/batch.py | 8 +- lightllm/server/router/manager.py | 58 +++++---- lightllm/server/router/stats.py | 123 +++++++++++++++++- lightllm/utils/log_utils.py | 72 ++++++---- 11 files changed, 228 insertions(+), 80 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661..7a24fd17f 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -4,6 +4,7 @@ import bisect import triton from typing import Optional +from tqdm import tqdm from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager @@ -197,7 +198,11 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs") + for batch_size in progress_bar: + avail_mem, _ = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB") seq_len = 2 total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch @@ -252,7 +257,13 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs") + for batch_size in progress_bar: + avail_mem, _ = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description( + f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB" + ) decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572f..3cbc5dc0f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -215,7 +215,7 @@ def _try_load_cache(self, static_key): cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key)) if os.path.exists(cache_file): - logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}") + logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}") with open(cache_file, "rb") as f: self.cached_configs[static_key] = orjson.loads(f.read()) return True @@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size): option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS, ) ) - logger.info(f"Saved configs for {self.kernel_name} - {_static_key}") + logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}") - logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished") + logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished") def _mutate_args_clone(self, args, kwargs): origin_list = [] diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 270e2a8cf..e127f0931 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -116,7 +116,6 @@ def set_args(self, args: StartArgs): app = FastAPI() g_objs.app = app -_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} _ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} _ACCESS_LOG_RESET = "\033[0m" diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3cf431d65..d191ee394 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -495,8 +495,6 @@ def normal_or_p_d_start(args): f"{args.host}:{args.port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.api_http:app", @@ -566,8 +564,6 @@ def pd_master_start(args): f"{args.host}:{args.port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.api_http:app", @@ -662,8 +658,6 @@ def config_server_start(args): f"{args.config_server_host}:{args.config_server_port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.config_server.api_http:app", diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 8c213914c..ae7103700 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes): req.link_prompt_ids_shm_array() req.link_logprobs_shm_array() - logger.info( + logger.debug( f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s" ) @@ -76,7 +76,10 @@ def handle_loop(self): for _ in range(recv_max_count): recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) assert isinstance(recv_obj, GroupReqIndexes) - self._add_new_group_req_index(recv_obj=recv_obj) + try: + self._add_new_group_req_index(recv_obj=recv_obj) + except Exception: + logger.exception("add new group req index has exception") # 当队列中存在较多的请求时,将一次接受的数量上调 recv_max_count = min(int(recv_max_count * 1.3), 256) @@ -160,7 +163,7 @@ def remove_finished_reqs(self): for decode_req in finished_reqs: decode_req.req.can_released_mark = True - logger.info(f"detoken release req id {decode_req.req.request_id}") + logger.debug(f"detoken release req id {decode_req.req.request_id}") self.shm_req_manager.put_back_req_obj(decode_req.req) self.req_id_to_out.pop(decode_req.request_id, None) return diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0f1b87311..374b089da 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -506,7 +506,7 @@ async def _log_req_header(self, request_headers, group_request_id: int): x_session_id = request_headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"received req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " @@ -744,7 +744,7 @@ async def _wait_to_token_package( (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_start_time} " f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " @@ -856,8 +856,8 @@ async def recycle_resource_loop(self): if req_status is None: continue - logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + logger.debug( + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 104da9f26..3d73dee47 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -197,7 +197,7 @@ async def _log_req_header(self, request: Request, group_request_id: int): x_request_id = request.headers.get("X-Request-Id", "") x_session_id = request.headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"received req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 24d0b9b82..34902d812 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union from lightllm.server.core.objs import ShmReqManager, Req from lightllm.utils.log_utils import init_logger -from .stats import RouterStatics logger = init_logger(__name__) @@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]: all_dp_req_num[req.sample_params.suggested_dp_index] += 1 return all_dp_req_num - def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics): + def filter_out_finished_req(self, shm_req_manager: ShmReqManager): unfinished_req_ids = [] for req in self.reqs: if req.shm_infer_released: - logger.info(f"router release req id {req.request_id}") - if not req.is_aborted: - router_statics.update(req.candetoken_out_len) - + logger.debug(f"router release req id {req.request_id}") shm_req_manager.put_back_req_obj(req) req = None else: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index dfb886660..f88ef53eb 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -16,6 +16,7 @@ from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue +from .stats import SystemStatusReporter from lightllm.server.core.objs.io_objs import ( GroupReqIndexes, AbortedReqCmd, @@ -25,7 +26,7 @@ from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer -from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.log_utils import init_logger from lightllm.utils.profiler import ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient @@ -68,6 +69,7 @@ def __init__(self, args: StartArgs): self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None + self.status_reporter = None # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -197,6 +199,11 @@ async def wait_to_model_ready(self): ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") + self.status_reporter = SystemStatusReporter( + args=self.args, + max_total_token_num=self.max_total_token_num, + dp_size_in_node=self.dp_size_in_node, + ) if self.args.run_mode == "prefill": from lightllm.server.router.model_infer.mode_backend.pd.prefill_node_impl import ( @@ -226,24 +233,9 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: + # Count output tokens (each running req produces ~1 token per decode step) + self.status_reporter.count_output_tokens(len(self.running_batch.reqs)) if counter_count % 100 == 0: - for dp_index in range(self.dp_size_in_node): - token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num - token_ratio2 = ( - self.max_total_token_num - - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - ) / self.max_total_token_num - d_i = dp_index - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) - paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) - logger.debug( - f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" - f"dp_i {d_i} paused req num: {paused_req_num} \n" - f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" - f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" - f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" - ) - logger.debug(self.router_statics.log_str()) self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode) @@ -265,11 +257,15 @@ async def loop_for_fwd( self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) - # 60s print once - if log_time_ready("token_load_info", 60): - for dp_i in range(self.dp_size_in_node): - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) - logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n") + + self.status_reporter.maybe_print( + running_batch=self.running_batch, + req_queue=self.req_queue, + read_only_statics_mem_manager=self.read_only_statics_mem_manager, + paused_req_num=self._get_paused_req_num(), + radix_cache_client=self.radix_cache_client, + disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache, + ) await asyncio.sleep(self._get_schedule_time_interval()) @@ -300,6 +296,7 @@ async def _step(self): async def _add_batch(self, batch: Batch): # 添加新请求 + self.status_reporter.count_prompt_tokens(batch.input_tokens()) reqs = [r.to_router_rpc_obj() for r in batch.reqs] while not self.shm_reqs_io_buffer.is_empty(): await asyncio.sleep(0.001) @@ -344,7 +341,16 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch): def _filter_reqs_from_running_batch(self): if self.running_batch is not None: - self.running_batch.filter_out_finished_req(self.shm_req_manager, self.router_statics) + # Capture finished req stats before filtering + for req in self.running_batch.reqs: + if req.shm_infer_released: + self.status_reporter.on_request_completed( + input_len=req.input_len, + output_len=req.shm_cur_output_len, + cache_len=req.prompt_cache_len, + mtp_accepted=req.mtp_accepted_token_num, + ) + self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None return @@ -416,7 +422,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes): req._router_stop_str_matched = False req_group.append(req) - logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") + logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s") self.req_queue.extend(req_group) self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) return @@ -431,6 +437,8 @@ def _generate_new_batch(self): logger.info(f"generate new batch, {new_batch.simple_log()}") self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch) + if self.schedule_new_batch is not None: + logger.debug(f"gen new batch, {self.schedule_new_batch.simple_log()}") return def _multinode_tp_generate_new_batch(self): diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index b715c5bcb..85548d913 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,7 +1,126 @@ -from lightllm.utils.log_utils import init_logger +import time +import logging from lightllm.server.core.objs import StartArgs +from lightllm.utils.log_utils import init_system_status_logger -logger = init_logger(__name__) +logger = logging.getLogger(__name__) + + +class SystemStatusReporter: + def __init__(self, args, max_total_token_num, dp_size_in_node): + self.enabled = not args.disable_log_stats + self.interval = max(5, args.log_stats_interval) + if args.log_stats_interval < 5: + logger.warning(f"log_stats_interval={args.log_stats_interval}s is below minimum, using 5s") + self.max_total_token_num = max_total_token_num + self.dp_size_in_node = dp_size_in_node + self.status_logger = init_system_status_logger("router") + + # Accumulation counters (reset each interval) + self.last_print_time = time.time() + self.prompt_tokens = 0 + self.output_tokens = 0 + + # Windowed counters for cache hit (reset each interval) + self.window_input_total = 0 + self.window_cache_total = 0 + + # Global counters (never reset, for lifetime stats) + self.global_input_total = 0 + self.global_cache_total = 0 + self.global_mtp_output_total = 0 + self.global_mtp_accepted_total = 0 + + def count_prompt_tokens(self, num_tokens: int): + if self.enabled: + self.prompt_tokens += num_tokens + + def count_output_tokens(self, num_tokens: int): + if self.enabled: + self.output_tokens += num_tokens + + def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int): + if self.enabled: + self.window_input_total += input_len + self.window_cache_total += cache_len + self.global_input_total += input_len + self.global_cache_total += cache_len + self.global_mtp_output_total += output_len + self.global_mtp_accepted_total += mtp_accepted + + def maybe_print( + self, + running_batch, + req_queue, + read_only_statics_mem_manager, + paused_req_num=0, + radix_cache_client=None, + disable_dynamic_prompt_cache=False, + ): + if not self.enabled: + return + now = time.time() + elapsed = now - self.last_print_time + if elapsed < self.interval: + return + + total_tps = (self.prompt_tokens + self.output_tokens) / elapsed + input_tps = self.prompt_tokens / elapsed + output_tps = self.output_tokens / elapsed + + running = len(running_batch.reqs) if running_batch else 0 + queued = req_queue.get_wait_req_num() + + # Memory utilization (average across dp) + # kv_used: physical KV memory usage (includes prefix cache tree occupancy) + # kv_used_no_cache: effective usage excluding unrefed prefix cache tokens + kv_used_list = [] + kv_used_no_cache_list = [] + for dp_i in range(self.dp_size_in_node): + unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i) + used = self.max_total_token_num - unrefed + kv_used_list.append(used / self.max_total_token_num) + if not disable_dynamic_prompt_cache and radix_cache_client is not None: + cache_unrefed = radix_cache_client.get_unrefed_tokens_num(dp_i) + kv_used_no_cache_list.append((used - cache_unrefed) / self.max_total_token_num) + else: + kv_used_no_cache_list.append(used / self.max_total_token_num) + avg_kv_used = sum(kv_used_list) / len(kv_used_list) + avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list) + + # Windowed prefix cache hit rate (this interval only) + window_cache_hit_rate = ( + (self.window_cache_total / self.window_input_total * 100) if self.window_input_total > 0 else 0.0 + ) + # Global prefix cache hit rate (lifetime) + global_cache_hit_rate = ( + (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 + ) + + kv_pct = avg_kv_used * 100 + kv_pct_no_cache = avg_kv_used_no_cache * 100 + + # Avg MTP accepted length (only shown when MTP is active) + mtp_suffix = "" + if self.global_mtp_accepted_total > 0: + decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total + avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1) + mtp_suffix = f", MTP {avg_mtp_len:.2f}" + + self.status_logger.info( + f"TPS {total_tps:.1f} (in {input_tps:.1f}, out {output_tps:.1f}), " + f"REQ {running}run, {queued}wait, {paused_req_num}pause, " + f"KV CACHE {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%), " + f"CACHE HIT {window_cache_hit_rate:.1f}% (global {global_cache_hit_rate:.1f}%)" + f"{mtp_suffix}" + ) + + # Reset windowed counters + self.prompt_tokens = 0 + self.output_tokens = 0 + self.window_input_total = 0 + self.window_cache_total = 0 + self.last_print_time = now class RouterStatics: diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index f15309d5c..c3057d18f 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -4,24 +4,41 @@ import logging import sys import os -import time from typing import Optional _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" -_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug") +_STATUS_FORMAT = "%(levelname)s [%(asctime)s] %(message)s" + +_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "info") _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0) _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None) +# ANSI color codes +_RESET = "\033[0m" +_LEVEL_COLORS = { + logging.DEBUG: "\033[36m", # cyan + logging.INFO: "\033[32m", # green + logging.WARNING: "\033[33m", # yellow + logging.ERROR: "\033[31m", # red + logging.CRITICAL: "\033[1;31m", # bold red +} + class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" + """Adds logging prefix to newlines to align multi-line messages, with optional color on levelname.""" - def __init__(self, fmt, datefmt=None): + def __init__(self, fmt, datefmt=None, use_color=False): logging.Formatter.__init__(self, fmt, datefmt) + self.use_color = use_color def format(self, record): + if self.use_color: + color = _LEVEL_COLORS.get(record.levelno, "") + if color: + record = logging.makeLogRecord(record.__dict__) + record.levelname = color + record.levelname + _RESET msg = logging.Formatter.format(self, record) if record.message != "": parts = msg.split(record.message) @@ -39,7 +56,9 @@ def _setup_logger(): _root_logger.setLevel(_LOG_LEVEL) global _default_handler global _default_file_handler - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _use_color = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + color_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=_use_color) + plain_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False) if _default_handler is None: _default_handler = logging.StreamHandler(sys.stdout) @@ -55,10 +74,10 @@ def _setup_logger(): _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) - _default_file_handler.setFormatter(fmt) + _default_file_handler.setFormatter(plain_fmt) _root_logger.addHandler(_default_file_handler) - _default_handler.setFormatter(fmt) + _default_handler.setFormatter(color_fmt) # Setting this will avoid the message # being propagated to the parent logger. _root_logger.propagate = False @@ -89,29 +108,28 @@ def init_logger(name: str): _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") _inference_log_file_handler[pid] = logging.FileHandler(_LOG_DIR + f"/process.{pid}.log") _inference_log_file_handler[pid].setLevel(_LOG_LEVEL) - _inference_log_file_handler[pid].setFormatter(NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)) + _inference_log_file_handler[pid].setFormatter( + NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False) + ) _root_logger.addHandler(_inference_log_file_handler[pid]) logger.addHandler(_inference_log_file_handler[pid]) logger.propagate = False return logger -_log_time_mark_dict = {} - - -def log_time_ready(mark_name, time_count: int): - """ - time_count 间隔时间超过多少s调用该函数会返回True,否则返回False - 用于控制一些日志输出的频率 - """ - global _log_time_mark_dict - - if mark_name not in _log_time_mark_dict: - _log_time_mark_dict[mark_name] = time.time() - return False - cur_time_mark = time.time() - if cur_time_mark - _log_time_mark_dict[mark_name] >= time_count: - _log_time_mark_dict[mark_name] = cur_time_mark - return True - else: - return False +def init_system_status_logger(name: str): + logger = logging.getLogger(f"lightllm.status.{name}") + if not logger.handlers: + logger.setLevel(logging.INFO) + fmt = logging.Formatter(_STATUS_FORMAT, datefmt=_DATE_FORMAT) + handler = logging.StreamHandler(sys.stdout) + handler.flush = sys.stdout.flush + handler.setFormatter(fmt) + logger.addHandler(handler) + if _LOG_DIR is not None: + file_handler = logging.FileHandler(os.path.join(_LOG_DIR, f"status.{name}.log")) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(fmt) + logger.addHandler(file_handler) + logger.propagate = False + return logger From c7502e4fe3930857b8682622ec710abe8da27e39 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 11:26:51 +0800 Subject: [PATCH 04/14] fix(httpserver): reject oversized prompts and translate ValueError to 400 Reject prompts whose character length exceeds max_req_total_len * 8 before tokenization, so a long string can no longer reach the tokenizer and stall the loop. The raised ValueError is caught one level up: log it at WARNING, release any held multimodal resources, abort the in-flight group request, and re-raise so the API layer (which already maps ValueError to HTTP 400) returns a graceful error to the client instead of a 500. --- lightllm/server/httpserver/manager.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 374b089da..945184340 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -317,6 +317,13 @@ async def generate( # 用于等待 pd_master 下发的交换信息 pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: + if isinstance(prompt, str): + max_prompt_chars = self.max_req_total_len * 8 + if len(prompt) > max_prompt_chars: + raise ValueError( + f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " + f"the request is rejected before tokenization." + ) start_time = time.time() request_headers = request.headers if request is not None else {} @@ -467,6 +474,12 @@ async def generate( yield sub_req_id, request_output, metadata, finish_status + except ValueError as e: + logger.warning(f"group_request_id: {group_request_id} request invalid: {str(e)}") + if group_request_id not in self.req_id_to_out_inf: + await self._release_multimodal_resources(multimodal_params) + await self.abort(group_request_id) + raise e except (ClientDisconnected, Exception) as e: logger.warning(f"group_request_id: {group_request_id} has exception {str(e)}") From de49c0e71ce227626ca43bbd32d76bd55bbadd25 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 13:26:35 +0800 Subject: [PATCH 05/14] fix(router): restore router_statics.update() on req completion The earlier refactor that moved the finished-req loop out of Batch.filter_out_finished_req into Router._filter_reqs_from_running_batch forgot to keep the router_statics.update(candetoken_out_len) call, freezing ema_req_out_len at its initial value. Multiple schedulers (chunked_prefill, beam, pd_decode, nixl_pd) read that EMA for KV-budget estimation, so leaving it stale degraded scheduling accuracy. --- lightllm/server/router/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index f88ef53eb..62d82e072 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -350,6 +350,7 @@ def _filter_reqs_from_running_batch(self): cache_len=req.prompt_cache_len, mtp_accepted=req.mtp_accepted_token_num, ) + self.router_statics.update(req.candetoken_out_len) self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None From eafb4a8cd2177e8d283ea23690f1b6a66c7c033c Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 15:28:54 +0800 Subject: [PATCH 06/14] fix(router): output TPS via per-req deltas, skip aborted reqs in stats MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two correctness fixes flagged in PR review: 1. count_output_tokens(len(running_batch.reqs)) once per router loop is wrong — the router loop polls on schedule_time_interval, decoupled from inference, so this overcounts when the loop is faster than decode and undercounts when slower, and includes paused/prefill-only reqs. Track shm_cur_output_len per request and accumulate the delta each tick (with a tail settlement when the req is filtered out so we don't lose its last tokens to the post-final-tick window). 2. on_request_completed() and router_statics.update() now both run for aborted requests, whose candetoken_out_len is a short partial value. Restore the prior `if not req.is_aborted` guard so disconnects don't bias the output-length EMA used by KV-budget estimators. --- lightllm/server/router/manager.py | 44 ++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 62d82e072..6dfd9b102 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -70,6 +70,9 @@ def __init__(self, args: StartArgs): # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None self.status_reporter = None + # Track shm_cur_output_len per running request to compute per-tick deltas + # for accurate output TPS regardless of router schedule interval. + self._req_last_output_len: Dict[int, int] = {} # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -233,8 +236,18 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: - # Count output tokens (each running req produces ~1 token per decode step) - self.status_reporter.count_output_tokens(len(self.running_batch.reqs)) + # Count output tokens via per-request shm_cur_output_len deltas, since the + # router loop runs on schedule_time_interval and len(reqs) is not a per-step + # token count. + new_output_tokens = 0 + for req in self.running_batch.reqs: + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.get(req.request_id, 0) + if cur_out_len > prev_out_len: + new_output_tokens += cur_out_len - prev_out_len + self._req_last_output_len[req.request_id] = cur_out_len + if new_output_tokens: + self.status_reporter.count_output_tokens(new_output_tokens) if counter_count % 100 == 0: self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently @@ -343,14 +356,25 @@ def _filter_reqs_from_running_batch(self): if self.running_batch is not None: # Capture finished req stats before filtering for req in self.running_batch.reqs: - if req.shm_infer_released: - self.status_reporter.on_request_completed( - input_len=req.input_len, - output_len=req.shm_cur_output_len, - cache_len=req.prompt_cache_len, - mtp_accepted=req.mtp_accepted_token_num, - ) - self.router_statics.update(req.candetoken_out_len) + if not req.shm_infer_released: + continue + # Settle any output-token delta produced after the last router tick + # so windowed TPS does not lose the request's tail tokens. + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.pop(req.request_id, 0) + if cur_out_len > prev_out_len: + self.status_reporter.count_output_tokens(cur_out_len - prev_out_len) + # Aborted/disconnected requests can leave a partial output_len that + # would bias the EMA toward shorter generations; skip them. + if req.is_aborted: + continue + self.status_reporter.on_request_completed( + input_len=req.input_len, + output_len=cur_out_len, + cache_len=req.prompt_cache_len, + mtp_accepted=req.mtp_accepted_token_num, + ) + self.router_statics.update(req.candetoken_out_len) self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None From f203f4b0e684dc8dd693389ee10fd819e3f2d028 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 16:05:59 +0800 Subject: [PATCH 07/14] perf(router): sweep output-token deltas once per print interval Move the per-running-req shm_cur_output_len delta tracking from the router tick (~33 Hz) into SystemStatusReporter.maybe_print, which only runs once per log_stats_interval (>= 5s). The reporter now owns the per-req snapshot dict and exposes discard_req(req) for tail settlement when a req leaves the running batch, so the router loop's hot path no longer walks the batch every schedule cycle. Output TPS accuracy is unchanged: still based on real shm_cur_output_len deltas, with tail tokens settled at completion. --- lightllm/server/router/manager.py | 28 ++++++---------------------- lightllm/server/router/stats.py | 28 +++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 6dfd9b102..15e05afb7 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -70,9 +70,6 @@ def __init__(self, args: StartArgs): # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None self.status_reporter = None - # Track shm_cur_output_len per running request to compute per-tick deltas - # for accurate output TPS regardless of router schedule interval. - self._req_last_output_len: Dict[int, int] = {} # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -236,18 +233,8 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: - # Count output tokens via per-request shm_cur_output_len deltas, since the - # router loop runs on schedule_time_interval and len(reqs) is not a per-step - # token count. - new_output_tokens = 0 - for req in self.running_batch.reqs: - cur_out_len = req.shm_cur_output_len - prev_out_len = self._req_last_output_len.get(req.request_id, 0) - if cur_out_len > prev_out_len: - new_output_tokens += cur_out_len - prev_out_len - self._req_last_output_len[req.request_id] = cur_out_len - if new_output_tokens: - self.status_reporter.count_output_tokens(new_output_tokens) + # Output-token counting is done in bulk at the print-window boundary + # inside SystemStatusReporter.maybe_print, so the router tick stays cheap. if counter_count % 100 == 0: self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently @@ -358,19 +345,16 @@ def _filter_reqs_from_running_batch(self): for req in self.running_batch.reqs: if not req.shm_infer_released: continue - # Settle any output-token delta produced after the last router tick - # so windowed TPS does not lose the request's tail tokens. - cur_out_len = req.shm_cur_output_len - prev_out_len = self._req_last_output_len.pop(req.request_id, 0) - if cur_out_len > prev_out_len: - self.status_reporter.count_output_tokens(cur_out_len - prev_out_len) + # Settle any output-token tail produced after the last window boundary, + # so windowed TPS does not lose the req's last tokens. + self.status_reporter.discard_req(req) # Aborted/disconnected requests can leave a partial output_len that # would bias the EMA toward shorter generations; skip them. if req.is_aborted: continue self.status_reporter.on_request_completed( input_len=req.input_len, - output_len=cur_out_len, + output_len=req.shm_cur_output_len, cache_len=req.prompt_cache_len, mtp_accepted=req.mtp_accepted_token_num, ) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index 85548d913..f6db924b5 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,5 +1,6 @@ import time import logging +from typing import Dict from lightllm.server.core.objs import StartArgs from lightllm.utils.log_utils import init_system_status_logger @@ -31,13 +32,23 @@ def __init__(self, args, max_total_token_num, dp_size_in_node): self.global_mtp_output_total = 0 self.global_mtp_accepted_total = 0 + # Per-req shm_cur_output_len snapshot at the previous window boundary, + # used to compute the windowed output-token count without per-tick scans. + self._req_last_output_len: Dict[int, int] = {} + def count_prompt_tokens(self, num_tokens: int): if self.enabled: self.prompt_tokens += num_tokens - def count_output_tokens(self, num_tokens: int): - if self.enabled: - self.output_tokens += num_tokens + def discard_req(self, req): + """Settle a finished/aborted req's tail output tokens (those produced after the last + window-boundary sweep) and drop its tracking entry.""" + if not self.enabled: + return + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.pop(req.request_id, 0) + if cur_out_len > prev_out_len: + self.output_tokens += cur_out_len - prev_out_len def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int): if self.enabled: @@ -64,6 +75,17 @@ def maybe_print( if elapsed < self.interval: return + # Single bulk sweep at the window boundary: account for output tokens produced + # by every still-running req since the previous boundary, and refresh their + # snapshots. Reqs that finished in this window already settled via discard_req. + if running_batch is not None: + for req in running_batch.reqs: + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.get(req.request_id, 0) + if cur_out_len > prev_out_len: + self.output_tokens += cur_out_len - prev_out_len + self._req_last_output_len[req.request_id] = cur_out_len + total_tps = (self.prompt_tokens + self.output_tokens) / elapsed input_tps = self.prompt_tokens / elapsed output_tps = self.output_tokens / elapsed From 064081c2d385eb9cba5cd26d2f454609858654de Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 11:46:40 +0800 Subject: [PATCH 08/14] fix(httpserver,router): defensive group_request_id init; reorder is_aborted skip - httpserver: initialize group_request_id=None so the ValueError except handler does not hit UnboundLocalError when the oversized-prompt guard raises before alloc_req_id. - router: move the is_aborted skip after on_request_completed so aborted reqs still update completion stats, but do not pollute the router_statics EMA with their truncated output_len. --- lightllm/server/httpserver/manager.py | 4 ++++ lightllm/server/router/manager.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 945184340..6f15f46ab 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -317,7 +317,11 @@ async def generate( # 用于等待 pd_master 下发的交换信息 pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: + group_request_id = None if isinstance(prompt, str): + # Guard against extremely long string prompts that might stall the tokenizer + # or cause excessive memory usage before tokenization. + # 8 characters per token is a conservative heuristic (avg is ~4). max_prompt_chars = self.max_req_total_len * 8 if len(prompt) > max_prompt_chars: raise ValueError( diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 15e05afb7..ff08963a6 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -348,16 +348,16 @@ def _filter_reqs_from_running_batch(self): # Settle any output-token tail produced after the last window boundary, # so windowed TPS does not lose the req's last tokens. self.status_reporter.discard_req(req) - # Aborted/disconnected requests can leave a partial output_len that - # would bias the EMA toward shorter generations; skip them. - if req.is_aborted: - continue self.status_reporter.on_request_completed( input_len=req.input_len, output_len=req.shm_cur_output_len, cache_len=req.prompt_cache_len, mtp_accepted=req.mtp_accepted_token_num, ) + # Aborted/disconnected requests can leave a partial output_len that + # would bias the EMA toward shorter generations; skip them. + if req.is_aborted: + continue self.router_statics.update(req.candetoken_out_len) self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): From 0405484e7b001a78c4f54aa34a00c097db3765f9 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 13:40:57 +0800 Subject: [PATCH 09/14] fix(detokenization): keep req registration failures fatal --- lightllm/server/detokenization/manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index ae7103700..90181d13a 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -76,10 +76,7 @@ def handle_loop(self): for _ in range(recv_max_count): recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) assert isinstance(recv_obj, GroupReqIndexes) - try: - self._add_new_group_req_index(recv_obj=recv_obj) - except Exception: - logger.exception("add new group req index has exception") + self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 recv_max_count = min(int(recv_max_count * 1.3), 256) From 42c754b31399bb31d423a2d166a7f3853ace3f3a Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 13:58:36 +0800 Subject: [PATCH 10/14] refactor(router): improve status log format --- lightllm/server/router/stats.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index f6db924b5..dac37b67b 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -122,20 +122,24 @@ def maybe_print( kv_pct = avg_kv_used * 100 kv_pct_no_cache = avg_kv_used_no_cache * 100 + log_parts = [ + f"router_status(window={elapsed:.1f}s)", + f"throughput(total={total_tps:.1f},input={input_tps:.1f},output={output_tps:.1f})", + f"req(running={running},waiting={queued},paused={paused_req_num})", + f"kv(used={kv_pct_no_cache:.1f}%)", + f"gpu_cache_hit(window={window_cache_hit_rate:.1f}%,global={global_cache_hit_rate:.1f}%)", + ] + # Avg MTP accepted length (only shown when MTP is active) - mtp_suffix = "" if self.global_mtp_accepted_total > 0: decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1) - mtp_suffix = f", MTP {avg_mtp_len:.2f}" - - self.status_logger.info( - f"TPS {total_tps:.1f} (in {input_tps:.1f}, out {output_tps:.1f}), " - f"REQ {running}run, {queued}wait, {paused_req_num}pause, " - f"KV CACHE {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%), " - f"CACHE HIT {window_cache_hit_rate:.1f}% (global {global_cache_hit_rate:.1f}%)" - f"{mtp_suffix}" - ) + log_parts.append( + f"mtp(avg_tokens_per_step={avg_mtp_len:.2f}," + f"accepted={self.global_mtp_accepted_total},output={self.global_mtp_output_total})" + ) + + self.status_logger.info(" | ".join(log_parts)) # Reset windowed counters self.prompt_tokens = 0 From 9cd7c4593e49dd6a46b52d034e44a81f346fbee1 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 14:09:17 +0800 Subject: [PATCH 11/14] fix(router): remove unused status log variable --- lightllm/server/router/stats.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index dac37b67b..8aac2b036 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -119,7 +119,6 @@ def maybe_print( (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 ) - kv_pct = avg_kv_used * 100 kv_pct_no_cache = avg_kv_used_no_cache * 100 log_parts = [ From 5a93c5334774f69f05ee87cd1c08b108a254ab4b Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 14:47:17 +0800 Subject: [PATCH 12/14] feat(router): add debug status details --- lightllm/server/router/stats.py | 53 +++++++++++++++++++++++++ lightllm/server/visualserver/manager.py | 2 +- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index 8aac2b036..c414a4386 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,5 +1,6 @@ import time import logging +import subprocess from typing import Dict from lightllm.server.core.objs import StartArgs from lightllm.utils.log_utils import init_system_status_logger @@ -59,6 +60,44 @@ def on_request_completed(self, input_len: int, output_len: int, cache_len: int, self.global_mtp_output_total += output_len self.global_mtp_accepted_total += mtp_accepted + def _get_gpu_status_for_debug(self) -> str: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=index,utilization.gpu,memory.used,memory.total", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + timeout=2, + ) + except (OSError, subprocess.SubprocessError) as e: + return f"gpu=unavailable({e.__class__.__name__})" + + gpu_infos = [] + for line in result.stdout.splitlines(): + parts = [part.strip() for part in line.split(",")] + if len(parts) != 4: + continue + gpu_index, util, mem_used, mem_total = parts + try: + mem_used_mb = float(mem_used) + mem_total_mb = float(mem_total) + mem_ratio = mem_used_mb / mem_total_mb * 100 if mem_total_mb > 0 else 0.0 + mem_used_gb = mem_used_mb / 1024 + mem_total_gb = mem_total_mb / 1024 + gpu_infos.append( + f"{gpu_index}(util={float(util):.0f}%,mem={mem_ratio:.1f}%," + f"used={mem_used_gb:.1f}GiB/{mem_total_gb:.1f}GiB)" + ) + except ValueError: + continue + if not gpu_infos: + return "gpu=unavailable(empty)" + return "gpu=[" + ";".join(gpu_infos) + "]" + def maybe_print( self, running_batch, @@ -119,6 +158,7 @@ def maybe_print( (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 ) + kv_pct = avg_kv_used * 100 kv_pct_no_cache = avg_kv_used_no_cache * 100 log_parts = [ @@ -139,6 +179,19 @@ def maybe_print( ) self.status_logger.info(" | ".join(log_parts)) + if logger.isEnabledFor(logging.DEBUG): + kv_unrefed_prefix_cache_pct = max(0.0, kv_pct - kv_pct_no_cache) + debug_parts = [ + "router_status_debug", + f"kv_physical={kv_pct:.1f}%", + f"kv_unrefed_prefix_cache={kv_unrefed_prefix_cache_pct:.1f}%", + f"throughput_tokens(input={self.prompt_tokens},output={self.output_tokens})", + f"gpu_cache_tokens(window={self.window_cache_total}/{self.window_input_total}," + f"global={self.global_cache_total}/{self.global_input_total})", + f"tracked_output_reqs={len(self._req_last_output_len)}", + self._get_gpu_status_for_debug(), + ] + logger.debug(" | ".join(debug_parts)) # Reset windowed counters self.prompt_tokens = 0 diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 1dffdaf68..ef8637853 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -176,7 +176,7 @@ async def loop_for_netio_req(self): while True: recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) if isinstance(recv_req, GroupReqIndexes): - logger.info( + logger.debug( f"visual recv req id {recv_req.group_req_id} " f"img count {len(recv_req.multimodal_params.images)}" ) From 1766c80cc32e3afbbf1a69bae0e5b3595ab8a55c Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 17 Jun 2026 11:05:09 +0800 Subject: [PATCH 13/14] fix(server): restore config access logs and prompt guard --- lightllm/server/access_log.py | 31 ++++++++++++++++ lightllm/server/api_http.py | 36 ++----------------- lightllm/server/config_server/api_http.py | 2 ++ lightllm/server/httpserver/manager.py | 23 +++--------- lightllm/server/httpserver/prompt_utils.py | 10 ++++++ .../httpserver_for_pd_master/manager.py | 2 ++ .../server/config_server/test_access_log.py | 14 ++++++++ .../httpserver/test_prompt_length_guard.py | 33 +++++++++++++++++ 8 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 lightllm/server/access_log.py create mode 100644 lightllm/server/httpserver/prompt_utils.py create mode 100644 unit_tests/server/config_server/test_access_log.py create mode 100644 unit_tests/server/httpserver/test_prompt_length_guard.py diff --git a/lightllm/server/access_log.py b/lightllm/server/access_log.py new file mode 100644 index 000000000..7365259c3 --- /dev/null +++ b/lightllm/server/access_log.py @@ -0,0 +1,31 @@ +_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} +_ACCESS_LOG_RESET = "\033[0m" + + +class _AccessLogMiddleware: + def __init__(self, app, logger): + self.app = app + self.logger = logger + + async def __call__(self, scope, receive, send): + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + status_holder = {"status": 0} + + async def send_wrapper(message): + if message["type"] == "http.response.start": + status_holder["status"] = message["status"] + await send(message) + + try: + await self.app(scope, receive, send_wrapper) + finally: + if scope["type"] == "http": + status = status_holder["status"] + msg = f"{scope['method']} {scope['path']} {status}" + color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "") + if color: + msg = color + msg + _ACCESS_LOG_RESET + self.logger.info(msg) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index e127f0931..c11778569 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -50,6 +50,7 @@ from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.access_log import _AccessLogMiddleware from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -115,40 +116,7 @@ def set_args(self, args: StartArgs): app = FastAPI() g_objs.app = app - -_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} -_ACCESS_LOG_RESET = "\033[0m" - - -class _AccessLogMiddleware: - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - if scope["type"] not in ("http", "websocket"): - await self.app(scope, receive, send) - return - - status_holder = {"status": 0} - - async def send_wrapper(message): - if message["type"] == "http.response.start": - status_holder["status"] = message["status"] - await send(message) - - try: - await self.app(scope, receive, send_wrapper) - finally: - if scope["type"] == "http": - status = status_holder["status"] - msg = f"{scope['method']} {scope['path']} {status}" - color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "") - if color: - msg = color + msg + _ACCESS_LOG_RESET - logger.info(msg) - - -app.add_middleware(_AccessLogMiddleware) +app.add_middleware(_AccessLogMiddleware, logger=logger) def create_error_response( diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index 3ce39bb6e..8b4e234e0 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.access_log import _AccessLogMiddleware from lightllm.server.visualserver.objs import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server @@ -18,6 +19,7 @@ logger = init_logger(__name__) app = FastAPI() +app.add_middleware(_AccessLogMiddleware, logger=logger) registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} registered_visual_server_objs: Dict[str, VIT_Obj] = {} diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 6f15f46ab..fda0ddd40 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -22,6 +22,7 @@ from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator from .async_queue import AsyncQueue +from .prompt_utils import validate_prompt_text_length from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE @@ -252,6 +253,7 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): kwargs = {} if kwargs is None else kwargs + validate_prompt_text_length(prompt, self.max_req_total_len) prompt_ids = self.tokenizer.encode(prompt, None, **kwargs) image_tokens = 0 img_count = 0 @@ -318,16 +320,7 @@ async def generate( pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: group_request_id = None - if isinstance(prompt, str): - # Guard against extremely long string prompts that might stall the tokenizer - # or cause excessive memory usage before tokenization. - # 8 characters per token is a conservative heuristic (avg is ~4). - max_prompt_chars = self.max_req_total_len * 8 - if len(prompt) > max_prompt_chars: - raise ValueError( - f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " - f"the request is rejected before tokenization." - ) + validate_prompt_text_length(prompt, self.max_req_total_len) start_time = time.time() request_headers = request.headers if request is not None else {} @@ -534,15 +527,7 @@ async def _encode( self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams ): if isinstance(prompt, str): - # pre-verify prompt length - # The average character length per token is always less than 8 - # TODO: automatically calculate the average character length per token - max_prompt_chars = self.max_req_total_len * 8 - if len(prompt) > max_prompt_chars: - raise ValueError( - f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " - f"the request is rejected before tokenization." - ) + validate_prompt_text_length(prompt, self.max_req_total_len) if self.enable_multimodal: assert ( len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity diff --git a/lightllm/server/httpserver/prompt_utils.py b/lightllm/server/httpserver/prompt_utils.py new file mode 100644 index 000000000..98f198c23 --- /dev/null +++ b/lightllm/server/httpserver/prompt_utils.py @@ -0,0 +1,10 @@ +def validate_prompt_text_length(prompt, max_req_total_len): + if not isinstance(prompt, str): + return + + max_prompt_chars = max_req_total_len * 8 + if len(prompt) > max_prompt_chars: + raise ValueError( + f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " + f"the request is rejected before tokenization." + ) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 3d73dee47..796a7a7ed 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -19,6 +19,7 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage from lightllm.server.httpserver.manager import AsyncQueue +from lightllm.server.httpserver.prompt_utils import validate_prompt_text_length from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from lightllm.utils.envs_utils import get_pd_split_max_new_tokens from .pd_selector import create_selector @@ -73,6 +74,7 @@ async def update_req_status(self, upkv_status: PDUpKVStatus): def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): kwargs = {} if kwargs is None else kwargs + validate_prompt_text_length(prompt, self.max_req_total_len) prompt_ids = self.tokenizer.encode(prompt, None, **kwargs) image_tokens = 0 img_count = 0 diff --git a/unit_tests/server/config_server/test_access_log.py b/unit_tests/server/config_server/test_access_log.py new file mode 100644 index 000000000..497823980 --- /dev/null +++ b/unit_tests/server/config_server/test_access_log.py @@ -0,0 +1,14 @@ +from fastapi.testclient import TestClient + + +def test_config_server_emits_access_log(monkeypatch): + from lightllm.server.config_server import api_http + + messages = [] + monkeypatch.setattr(api_http.logger, "info", lambda msg, *args, **kwargs: messages.append(str(msg))) + + with TestClient(api_http.app) as client: + response = client.get("/health") + + assert response.status_code == 200 + assert any("GET /health 200" in message for message in messages) diff --git a/unit_tests/server/httpserver/test_prompt_length_guard.py b/unit_tests/server/httpserver/test_prompt_length_guard.py new file mode 100644 index 000000000..f03a59a32 --- /dev/null +++ b/unit_tests/server/httpserver/test_prompt_length_guard.py @@ -0,0 +1,33 @@ +from types import SimpleNamespace + +import pytest + +from lightllm.server.httpserver.manager import HttpServerManager +from lightllm.server.httpserver_for_pd_master.manager import HttpServerManagerForPDMaster + + +class TokenizerMustNotRun: + vocab_size = 32000 + + def encode(self, *args, **kwargs): + raise AssertionError("tokenizer should not run for oversized text prompts") + + +def _fake_manager(): + return SimpleNamespace( + max_req_total_len=4, + tokenizer=TokenizerMustNotRun(), + args=SimpleNamespace(max_image_token_count=1024), + ) + + +def _empty_multimodal_params(): + return SimpleNamespace(images=[], audios=[]) + + +@pytest.mark.parametrize("manager_cls", [HttpServerManager, HttpServerManagerForPDMaster]) +def test_tokens_rejects_oversized_text_prompt_before_tokenization(manager_cls): + prompt = "x" * 33 + + with pytest.raises(ValueError, match="prompt text length 33 exceeds the character limit 32"): + manager_cls.tokens(_fake_manager(), prompt, _empty_multimodal_params(), SimpleNamespace()) From 4360f5ccc85afe81023a8b242fff172bf4602a65 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 17 Jun 2026 22:32:01 +0800 Subject: [PATCH 14/14] chore(router): trim redundant logging comments --- lightllm/server/router/manager.py | 5 ----- lightllm/server/router/stats.py | 8 -------- lightllm/utils/log_utils.py | 1 - 3 files changed, 14 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ff08963a6..e9d90e864 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -233,8 +233,6 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: - # Output-token counting is done in bulk at the print-window boundary - # inside SystemStatusReporter.maybe_print, so the router tick stays cheap. if counter_count % 100 == 0: self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently @@ -341,12 +339,9 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch): def _filter_reqs_from_running_batch(self): if self.running_batch is not None: - # Capture finished req stats before filtering for req in self.running_batch.reqs: if not req.shm_infer_released: continue - # Settle any output-token tail produced after the last window boundary, - # so windowed TPS does not lose the req's last tokens. self.status_reporter.discard_req(req) self.status_reporter.on_request_completed( input_len=req.input_len, diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index c414a4386..94556f21c 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -18,16 +18,13 @@ def __init__(self, args, max_total_token_num, dp_size_in_node): self.dp_size_in_node = dp_size_in_node self.status_logger = init_system_status_logger("router") - # Accumulation counters (reset each interval) self.last_print_time = time.time() self.prompt_tokens = 0 self.output_tokens = 0 - # Windowed counters for cache hit (reset each interval) self.window_input_total = 0 self.window_cache_total = 0 - # Global counters (never reset, for lifetime stats) self.global_input_total = 0 self.global_cache_total = 0 self.global_mtp_output_total = 0 @@ -132,7 +129,6 @@ def maybe_print( running = len(running_batch.reqs) if running_batch else 0 queued = req_queue.get_wait_req_num() - # Memory utilization (average across dp) # kv_used: physical KV memory usage (includes prefix cache tree occupancy) # kv_used_no_cache: effective usage excluding unrefed prefix cache tokens kv_used_list = [] @@ -149,11 +145,9 @@ def maybe_print( avg_kv_used = sum(kv_used_list) / len(kv_used_list) avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list) - # Windowed prefix cache hit rate (this interval only) window_cache_hit_rate = ( (self.window_cache_total / self.window_input_total * 100) if self.window_input_total > 0 else 0.0 ) - # Global prefix cache hit rate (lifetime) global_cache_hit_rate = ( (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 ) @@ -169,7 +163,6 @@ def maybe_print( f"gpu_cache_hit(window={window_cache_hit_rate:.1f}%,global={global_cache_hit_rate:.1f}%)", ] - # Avg MTP accepted length (only shown when MTP is active) if self.global_mtp_accepted_total > 0: decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1) @@ -193,7 +186,6 @@ def maybe_print( ] logger.debug(" | ".join(debug_parts)) - # Reset windowed counters self.prompt_tokens = 0 self.output_tokens = 0 self.window_input_total = 0 diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index c3057d18f..6df36c3ae 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -15,7 +15,6 @@ _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0) _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None) -# ANSI color codes _RESET = "\033[0m" _LEVEL_COLORS = { logging.DEBUG: "\033[36m", # cyan