From 88cc50c393c4832f028d052b61d0df4fab2df1e1 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Thu, 18 Jun 2026 12:44:35 +0000
Subject: [PATCH 01/14] fix stream fc for qwen3_coder
---
lightllm/server/function_call_parser.py | 174 +++++++++++++++++++++++-
1 file changed, 169 insertions(+), 5 deletions(-)
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index dfcb2f8d9..6d42ca87a 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -1758,6 +1758,10 @@ def __init__(self):
self.parameter_regex = re.compile(
r"|(?=)|$)", re.DOTALL
)
+ self.parameter_stream_regex = re.compile(
+ r"]*)>(.*?)(|(?=)|$)", re.DOTALL
+ )
+ self.function_start_regex = re.compile(r"]*)>", re.DOTALL)
self._normal_text_buffer = ""
def has_tool_call(self, text: str) -> bool:
@@ -1890,6 +1894,111 @@ def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools
return json.dumps(param_dict, ensure_ascii=False)
+ def _get_qwen3_param_type(self, param_name: str, param_config: Dict) -> str:
+ prop = param_config.get(param_name, {})
+ return str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string"
+
+ def _json_string_prefix(self, value: str, closed: bool) -> str:
+ dumped = json.dumps(value, ensure_ascii=False)
+ return dumped if closed else dumped[:-1]
+
+ def _strip_partial_xml_suffix(self, value: str) -> str:
+ for token in ("", "", self.eot_token):
+ max_len = min(len(value), len(token) - 1)
+ for suffix_len in range(max_len, 0, -1):
+ if token.startswith(value[-suffix_len:]):
+ return value[:-suffix_len]
+ return value
+
+ def _build_streaming_arguments_json(
+ self,
+ func_name: str,
+ partial_body: str,
+ tools: List[Tool],
+ close_object: bool = False,
+ ) -> Optional[str]:
+ """Build a monotonic JSON arguments prefix for XML tool-call streaming."""
+ param_config = self._get_param_config(func_name, tools)
+ parts = ["{"]
+ has_param = False
+
+ for match in self.parameter_stream_regex.finditer(partial_body):
+ param_name = match.group(1).strip()
+ if not param_name:
+ continue
+
+ param_value = match.group(2)
+ closed = match.group(3) == ""
+ if not closed:
+ param_value = self._strip_partial_xml_suffix(param_value)
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ if closed:
+ param_value = param_value[:-1]
+ else:
+ # Keep the template delimiter newline buffered until we know
+ # whether it is part of the value or just precedes .
+ param_value = param_value[:-1]
+
+ if has_param:
+ parts.append(", ")
+ parts.append(json.dumps(param_name, ensure_ascii=False))
+ parts.append(": ")
+
+ param_type = self._get_qwen3_param_type(param_name, param_config)
+ if param_type in ("string", "str", "enum") or param_name not in param_config:
+ parts.append(self._json_string_prefix(param_value, closed))
+ elif closed:
+ converted = self._convert_param_value(param_value, param_name, param_config, func_name)
+ converted_json = json.dumps(converted, ensure_ascii=False)
+ if converted_json.startswith(param_value):
+ parts.append(converted_json)
+ else:
+ parts.append(param_value)
+ else:
+ parts.append(param_value)
+
+ has_param = True
+ if not closed:
+ break
+
+ if not has_param and close_object:
+ return "{}"
+ if not has_param:
+ return None
+
+ if close_object:
+ parts.append("}")
+
+ return "".join(parts)
+
+ def _ensure_qwen3_stream_state(self, tool_index: int) -> None:
+ while len(self.prev_tool_call_arr) <= tool_index:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= tool_index:
+ self.streamed_args_for_tool.append("")
+
+ def _append_qwen3_arguments_delta(self, calls: List[ToolCallItem], tool_index: int, current_args_json: str) -> None:
+ sent_args = self.streamed_args_for_tool[tool_index]
+ if not current_args_json.startswith(sent_args):
+ logger.warning(
+ "Qwen3-Coder streaming arguments are not monotonic for tool index %s; skip delta.",
+ tool_index,
+ )
+ return
+
+ argument_diff = current_args_json[len(sent_args) :]
+ if argument_diff:
+ calls.append(
+ ToolCallItem(
+ tool_index=tool_index,
+ name=None,
+ parameters=argument_diff,
+ )
+ )
+ self.streamed_args_for_tool[tool_index] += argument_diff
+
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text
@@ -1941,23 +2050,78 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
self._buffer = current_text[tool_call_start:]
current_text = self._buffer
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+
+ self._ensure_qwen3_stream_state(self.current_tool_id)
+
+ function_match = self.function_start_regex.search(current_text)
+ if not function_match:
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ func_name = function_match.group(1).strip()
eot_pos = current_text.find(self.eot_token)
+ if func_name not in self._tool_indices:
+ if eot_pos == -1:
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+ logger.warning(f"Model attempted to call undefined function: {func_name}")
+ self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
+ self.current_tool_name_sent = False
+ continue
+
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+
+ function_close_pos = current_text.find("", function_match.end())
+ close_object = function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos)
+ partial_end = function_close_pos if close_object else eot_pos
+ if partial_end == -1:
+ partial_end = len(current_text)
+ partial_body = current_text[function_match.end() : partial_end]
+ current_args_json = self._build_streaming_arguments_json(
+ func_name,
+ partial_body,
+ tools,
+ close_object=close_object,
+ )
+ if current_args_json:
+ self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json)
+
if eot_pos == -1:
return StreamingParseResult(normal_text=normal_text, calls=calls)
complete_block = current_text[: eot_pos + len(self.eot_token)]
func_matches = self.function_regex.findall(complete_block)
- if self.current_tool_id == -1:
- self.current_tool_id = 0
-
for match in func_matches:
func_str = match[0] if match[0] else match[1]
item = self._parse_function_call(func_str, tools)
if item:
- item.tool_index = self.current_tool_id
- calls.append(item)
+ completed_tool_id = self.current_tool_id
+ try:
+ parsed_args = json.loads(item.parameters)
+ except json.JSONDecodeError:
+ parsed_args = {}
+ self.prev_tool_call_arr[completed_tool_id] = {
+ "name": item.name,
+ "arguments": parsed_args,
+ }
+ sent_args = self.streamed_args_for_tool[completed_tool_id]
+ if item.parameters.startswith(sent_args):
+ self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters)
self.current_tool_id += 1
+ self.current_tool_name_sent = False
self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
From a62103d588e6f364b48da99030d21e643485d275 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Thu, 18 Jun 2026 21:58:30 +0800
Subject: [PATCH 02/14] fix(qwen3_coder): correct streaming tool-call argument
parsing
Keep streamed_args_for_tool a byte-exact prefix of json.dumps(arguments) so
the serving-layer stop-time reconciliation can't produce duplicated/invalid
JSON. String values stream incrementally; non-string values are emitted only
once arrives (a partial number/array/bool isn't a prefix of its
json.dumps form). Detect "value still streaming" via the terminator rather
than match position ($ matches before the template's trailing newline).
Also:
- don't crash (IndexError) on >=2 in one block, and
emit a name head for each; flush via _ensure_qwen3_stream_state.
- an undefined first function no longer discards the whole block (a valid
function after it is still emitted).
- treat as an implicit function close so the closing '}' rides
in the same args delta (no separate trailing-'}' delta to double-count).
- drop dead _build_partial_arguments_json; dedup newline-strip / param-type
helpers.
Add unit_tests/server/test_qwen3_coder_stream_fc.py covering these across
chunk boundaries.
---
lightllm/server/function_call_parser.py | 248 ++++++++----------
.../server/test_qwen3_coder_stream_fc.py | 194 ++++++++++++++
2 files changed, 310 insertions(+), 132 deletions(-)
create mode 100644 unit_tests/server/test_qwen3_coder_stream_fc.py
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index 6d42ca87a..093c0c568 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -1786,8 +1786,7 @@ def _convert_param_value(self, value: str, param_name: str, param_config: Dict,
if param_name not in param_config:
return value
- prop = param_config.get(param_name, {})
- param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string"
+ param_type = self._get_qwen3_param_type(param_name, param_config)
if param_type in ("string", "str", "enum"):
return value
@@ -1837,12 +1836,7 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional
except ValueError:
continue
param_name = match[:idx].strip()
- param_value = match[idx + 1 :]
- # Strip leading/trailing newlines from value
- if param_value.startswith("\n"):
- param_value = param_value[1:]
- if param_value.endswith("\n"):
- param_value = param_value[:-1]
+ param_value = self._strip_value_newlines(match[idx + 1 :])
param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name)
@@ -1852,55 +1846,17 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional
parameters=json.dumps(param_dict, ensure_ascii=False),
)
- def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools: List[Tool]) -> Optional[str]:
- """Build the current argument JSON from a partial XML tool-call body."""
- param_matches = self.parameter_regex.findall(partial_body)
- if not param_matches:
- return None
-
- param_config = self._get_param_config(func_name, tools)
- param_dict = {}
- has_visible_value = False
-
- for match in param_matches:
- try:
- idx = match.index(">")
- except ValueError:
- continue
-
- param_name = match[:idx].strip()
- param_value = match[idx + 1 :]
- if param_value.startswith("\n"):
- param_value = param_value[1:]
- if param_value.endswith("\n"):
- param_value = param_value[:-1]
-
- if param_value.strip():
- has_visible_value = True
- elif (
- f"" in partial_body
- and f"{param_value}" in partial_body
- ):
- # Closed empty-string parameter. We can safely emit it.
- has_visible_value = True
- else:
- # Parameter tag is present but its value has not started streaming yet.
- continue
-
- param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name)
-
- if not param_dict and not has_visible_value:
- return None
-
- return json.dumps(param_dict, ensure_ascii=False)
-
def _get_qwen3_param_type(self, param_name: str, param_config: Dict) -> str:
prop = param_config.get(param_name, {})
return str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string"
- def _json_string_prefix(self, value: str, closed: bool) -> str:
- dumped = json.dumps(value, ensure_ascii=False)
- return dumped if closed else dumped[:-1]
+ def _strip_value_newlines(self, value: str) -> str:
+ """Strip the single leading/trailing newline the Qwen3 template wraps each value in."""
+ if value.startswith("\n"):
+ value = value[1:]
+ if value.endswith("\n"):
+ value = value[:-1]
+ return value
def _strip_partial_xml_suffix(self, value: str) -> str:
for token in ("", "", self.eot_token):
@@ -1917,7 +1873,14 @@ def _build_streaming_arguments_json(
tools: List[Tool],
close_object: bool = False,
) -> Optional[str]:
- """Build a monotonic JSON arguments prefix for XML tool-call streaming."""
+ """Build a monotonic JSON arguments prefix for XML tool-call streaming.
+
+ The result is always a byte-exact prefix of json.dumps(final_arguments) so the
+ serving layer (api_openai.py) can reconcile the streamed args at stream end.
+ String values stream character-by-character (a string prefix stays a prefix);
+ non-string values are only emitted once their arrives, because a
+ partial number/array/bool is not guaranteed to be a prefix of its json.dumps form.
+ """
param_config = self._get_param_config(func_name, tools)
parts = ["{"]
has_param = False
@@ -1927,46 +1890,45 @@ def _build_streaming_arguments_json(
if not param_name:
continue
- param_value = match.group(2)
- closed = match.group(3) == ""
- if not closed:
- param_value = self._strip_partial_xml_suffix(param_value)
- if param_value.startswith("\n"):
- param_value = param_value[1:]
- if param_value.endswith("\n"):
- if closed:
- param_value = param_value[:-1]
- else:
- # Keep the template delimiter newline buffered until we know
- # whether it is part of the value or just precedes .
- param_value = param_value[:-1]
+ # The value is complete only when an explicit closed it, or a
+ # sibling follows. Otherwise it is still streaming.
+ # (We can't key off match.end()==len: `$` matches before a trailing newline,
+ # and the template wraps every value in one, which would look "complete".)
+ rest = partial_body[match.end() :]
+ value_open = (
+ match.group(3) != ""
+ and not rest.startswith("")
+ )
if has_param:
parts.append(", ")
parts.append(json.dumps(param_name, ensure_ascii=False))
parts.append(": ")
+ has_param = True
param_type = self._get_qwen3_param_type(param_name, param_config)
- if param_type in ("string", "str", "enum") or param_name not in param_config:
- parts.append(self._json_string_prefix(param_value, closed))
- elif closed:
- converted = self._convert_param_value(param_value, param_name, param_config, func_name)
- converted_json = json.dumps(converted, ensure_ascii=False)
- if converted_json.startswith(param_value):
- parts.append(converted_json)
- else:
- parts.append(param_value)
+ is_string = param_type in ("string", "str", "enum")
+
+ if value_open:
+ # In-progress (and therefore last) parameter.
+ if is_string:
+ value = self._strip_value_newlines(self._strip_partial_xml_suffix(match.group(2)))
+ # Drop the closing quote so the stream stays an extendable prefix.
+ parts.append(json.dumps(value, ensure_ascii=False)[:-1])
+ # Non-string values cannot be emitted as a safe partial prefix, so stop
+ # after the key and wait for the value to close.
+ return "".join(parts)
+
+ value = self._strip_value_newlines(match.group(2))
+ if is_string:
+ parts.append(json.dumps(value, ensure_ascii=False))
else:
- parts.append(param_value)
-
- has_param = True
- if not closed:
- break
+ converted = self._convert_param_value(value, param_name, param_config, func_name)
+ parts.append(json.dumps(converted, ensure_ascii=False))
- if not has_param and close_object:
- return "{}"
if not has_param:
- return None
+ return "{}" if close_object else None
if close_object:
parts.append("}")
@@ -2061,42 +2023,51 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
func_name = function_match.group(1).strip()
eot_pos = current_text.find(self.eot_token)
- if func_name not in self._tool_indices:
- if eot_pos == -1:
- return StreamingParseResult(normal_text=normal_text, calls=calls)
- logger.warning(f"Model attempted to call undefined function: {func_name}")
- self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
- self.current_tool_name_sent = False
- continue
+ func_defined = func_name in self._tool_indices
- if not self.current_tool_name_sent:
- calls.append(
- ToolCallItem(
- tool_index=self.current_tool_id,
- name=func_name,
- parameters="",
+ # Undefined function whose block has not finished yet: wait for more text
+ # (the block may also contain a valid function we shouldn't drop).
+ if not func_defined and eot_pos == -1:
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ if func_defined:
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
)
- )
- self.current_tool_name_sent = True
- self.prev_tool_call_arr[self.current_tool_id] = {
- "name": func_name,
- "arguments": {},
- }
+ self.current_tool_name_sent = True
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
- function_close_pos = current_text.find("", function_match.end())
- close_object = function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos)
- partial_end = function_close_pos if close_object else eot_pos
- if partial_end == -1:
- partial_end = len(current_text)
- partial_body = current_text[function_match.end() : partial_end]
- current_args_json = self._build_streaming_arguments_json(
- func_name,
- partial_body,
- tools,
- close_object=close_object,
- )
- if current_args_json:
- self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json)
+ # The function body is complete once we hit either or the
+ # enclosing ; treating eot as an implicit close lets us emit
+ # the closing '}' inside the same args delta (so the serving layer's
+ # stop-time reconciliation sees one delta, not a separate trailing '}').
+ function_close_pos = current_text.find("", function_match.end())
+ if function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos):
+ partial_end = function_close_pos
+ close_object = True
+ elif eot_pos != -1:
+ partial_end = eot_pos
+ close_object = True
+ else:
+ partial_end = len(current_text)
+ close_object = False
+ partial_body = current_text[function_match.end() : partial_end]
+ current_args_json = self._build_streaming_arguments_json(
+ func_name,
+ partial_body,
+ tools,
+ close_object=close_object,
+ )
+ if current_args_json:
+ self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json)
if eot_pos == -1:
return StreamingParseResult(normal_text=normal_text, calls=calls)
@@ -2104,24 +2075,37 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
complete_block = current_text[: eot_pos + len(self.eot_token)]
func_matches = self.function_regex.findall(complete_block)
+ # Flush every completed function in the block. _parse_function_call returns
+ # None for undefined ones, so they are skipped without advancing the index.
for match in func_matches:
func_str = match[0] if match[0] else match[1]
item = self._parse_function_call(func_str, tools)
- if item:
- completed_tool_id = self.current_tool_id
- try:
- parsed_args = json.loads(item.parameters)
- except json.JSONDecodeError:
- parsed_args = {}
- self.prev_tool_call_arr[completed_tool_id] = {
- "name": item.name,
- "arguments": parsed_args,
- }
- sent_args = self.streamed_args_for_tool[completed_tool_id]
- if item.parameters.startswith(sent_args):
- self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters)
- self.current_tool_id += 1
- self.current_tool_name_sent = False
+ if not item:
+ continue
+ completed_tool_id = self.current_tool_id
+ self._ensure_qwen3_stream_state(completed_tool_id)
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=completed_tool_id,
+ name=item.name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ try:
+ parsed_args = json.loads(item.parameters)
+ except json.JSONDecodeError:
+ parsed_args = {}
+ self.prev_tool_call_arr[completed_tool_id] = {
+ "name": item.name,
+ "arguments": parsed_args,
+ }
+ sent_args = self.streamed_args_for_tool[completed_tool_id]
+ if item.parameters.startswith(sent_args):
+ self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters)
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
diff --git a/unit_tests/server/test_qwen3_coder_stream_fc.py b/unit_tests/server/test_qwen3_coder_stream_fc.py
new file mode 100644
index 000000000..f72a642b1
--- /dev/null
+++ b/unit_tests/server/test_qwen3_coder_stream_fc.py
@@ -0,0 +1,194 @@
+"""Unit tests for Qwen3-Coder XML streaming tool-call parsing.
+
+These drive ``Qwen3CoderDetector.parse_streaming_increment`` directly (no server),
+reassembling the per-tool argument string exactly the way ``api_openai.py`` does for
+streamed responses, including the ``finish_reason == "stop"`` reconciliation. The key
+invariant under test: the reassembled arguments are always valid JSON equal to what the
+one-shot ``detect_and_parse`` would produce, for every chunk boundary.
+"""
+
+import json
+import pytest
+
+from lightllm.server.api_models import Function, Tool
+from lightllm.server.function_call_parser import Qwen3CoderDetector
+
+CHUNK_SIZES = [1, 2, 3, 5, 13, 10_000]
+
+
+def _tool(name, properties):
+ return Tool(
+ type="function",
+ function=Function(name=name, description="", parameters={"type": "object", "properties": properties}),
+ )
+
+
+def _stream_and_reassemble(text, tools, chunk):
+ """Feed ``text`` to the detector in fixed-size chunks and rebuild the client-visible
+ tool calls the way api_openai.py stream_results does (stop-rewrite on the last chunk)."""
+ det = Qwen3CoderDetector()
+ chunks = [text[i : i + chunk] for i in range(0, len(text), chunk)]
+ per_tool = {}
+ for ci, piece in enumerate(chunks):
+ result = det.parse_streaming_increment(piece, tools)
+ is_last = ci == len(chunks) - 1
+ for call in result.calls:
+ ti = call.tool_index
+ per_tool.setdefault(ti, {"name": None, "args": ""})
+ if call.name is not None:
+ per_tool[ti]["name"] = call.name
+ params = call.parameters
+ if is_last and params:
+ # Mirror api_openai.py:559-575 (REPLACE semantics).
+ latest_delta_len = len(params)
+ expected = json.dumps(det.prev_tool_call_arr[ti].get("arguments", {}), ensure_ascii=False)
+ actual = det.streamed_args_for_tool[ti]
+ if latest_delta_len > 0:
+ actual = actual[:-latest_delta_len]
+ params = expected.replace(actual, "", 1)
+ if params:
+ per_tool[ti]["args"] += params
+ return det, per_tool
+
+
+def _assert_tool_calls(text, tools, expected, chunk_sizes=CHUNK_SIZES):
+ """expected: {tool_index: (name, args_dict)}."""
+ for chunk in chunk_sizes:
+ det, per_tool = _stream_and_reassemble(text, tools, chunk)
+ assert len(per_tool) == len(expected), f"chunk={chunk}: tool count {len(per_tool)} != {len(expected)}"
+ for ti, (name, args) in expected.items():
+ got = per_tool[ti]
+ assert got["name"] == name, f"chunk={chunk}: tool {ti} name {got['name']!r} != {name!r}"
+ parsed = json.loads(got["args"]) # must be valid JSON
+ assert parsed == args, f"chunk={chunk}: tool {ti} args {parsed!r} != {args!r}"
+
+
+@pytest.mark.parametrize("chunk", CHUNK_SIZES)
+def test_single_string_param(chunk):
+ text = (
+ "\n\n\n"
+ "San Francisco\n\n\n"
+ )
+ _assert_tool_calls(
+ text,
+ [_tool("get_weather", {"location": {"type": "string"}})],
+ {0: ("get_weather", {"location": "San Francisco"})},
+ [chunk],
+ )
+
+
+def test_array_param_compact_spacing():
+ # Regression: a non-string value whose raw text ("[1,2]") differs from json.dumps
+ # ("[1, 2]") used to break the streamed-args prefix invariant -> duplicated/invalid JSON.
+ text = "\n\n\n[1,2]\n\n\n"
+ _assert_tool_calls(text, [_tool("calc", {"nums": {"type": "array"}})], {0: ("calc", {"nums": [1, 2]})})
+
+
+def test_number_param_reformatted():
+ # "1.0" is parsed to int 1 and json.dumps'd as "1"; the stream must agree.
+ text = "\n\n\n1.0\n\n\n"
+ _assert_tool_calls(text, [_tool("calc", {"v": {"type": "number"}})], {0: ("calc", {"v": 1})})
+
+
+def test_boolean_param():
+ text = "\n\n\ntrue\n\n\n"
+ _assert_tool_calls(text, [_tool("set", {"flag": {"type": "boolean"}})], {0: ("set", {"flag": True})})
+
+
+def test_object_param():
+ text = '\n\n\n{"a":1,"b":[2,3]}\n\n\n'
+ _assert_tool_calls(text, [_tool("f", {"cfg": {"type": "object"}})], {0: ("f", {"cfg": {"a": 1, "b": [2, 3]}})})
+
+
+def test_two_params_mixed_types():
+ text = (
+ "\n\n\nNYC\n\n"
+ "\n3\n\n\n"
+ )
+ _assert_tool_calls(
+ text,
+ [_tool("f", {"city": {"type": "string"}, "days": {"type": "integer"}})],
+ {0: ("f", {"city": "NYC", "days": 3})},
+ )
+
+
+def test_multiline_string_value():
+ text = (
+ "\n\n\nline1\nline2\nline3\n" "\n\n"
+ )
+ _assert_tool_calls(text, [_tool("f", {"code": {"type": "string"}})], {0: ("f", {"code": "line1\nline2\nline3"})})
+
+
+def test_string_with_json_special_chars():
+ text = '\n\n\nsay "hi"\\path\n\n\n'
+ _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": 'say "hi"\\path'})})
+
+
+def test_empty_string_value():
+ text = "\n\n\n\n\n\n"
+ _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": ""})})
+
+
+def test_no_param_function():
+ text = "\n\n\n"
+ _assert_tool_calls(text, [_tool("ping", {})], {0: ("ping", {})})
+
+
+def test_two_separate_tool_call_blocks():
+ text = (
+ "\n\n\nhi\n\n\n\n"
+ "\n\n\nyo\n\n\n"
+ )
+ _assert_tool_calls(
+ text,
+ [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})],
+ {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})},
+ )
+
+
+def test_two_functions_in_one_block():
+ # Regression: used to raise IndexError on the second function in a single block.
+ text = (
+ "\n\n\nhi\n\n\n"
+ "\n\nyo\n\n\n"
+ )
+ _assert_tool_calls(
+ text,
+ [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})],
+ {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})},
+ )
+
+
+def test_undefined_then_valid_in_same_block():
+ # Regression: an undefined first function used to discard the whole block, dropping
+ # the valid call that followed it.
+ text = (
+ "\n\n\nhi\n\n\n"
+ "\n\nyo\n\n\n"
+ )
+ _assert_tool_calls(text, [_tool("valid", {"y": {"type": "string"}})], {0: ("valid", {"y": "yo"})})
+
+
+def test_truncated_call_missing_function_close():
+ # Regression: a typed value with no before used to leave the
+ # streamed args unterminated (missing closing brace).
+ text = "\n\n\n0.50\n\n"
+ _assert_tool_calls(text, [_tool("calc", {"x": {"type": "number"}})], {0: ("calc", {"x": 0.5})})
+
+
+def test_streaming_matches_non_stream():
+ # The reassembled streamed args must equal the one-shot detect_and_parse output.
+ tools = [_tool("f", {"city": {"type": "string"}, "n": {"type": "integer"}, "tags": {"type": "array"}})]
+ text = (
+ "\n\n\nLondon\n\n"
+ '\n7\n\n\n["a","b"]\n\n\n'
+ )
+ oneshot = Qwen3CoderDetector().detect_and_parse(text, tools)
+ expected_args = json.loads(oneshot.calls[0].parameters)
+ _assert_tool_calls(text, tools, {0: ("f", expected_args)})
+
+
+if __name__ == "__main__":
+ import sys
+
+ sys.exit(pytest.main([__file__, "-v"]))
From 7093bdadf88cd4227f7492f148e6d793608473b0 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 6 May 2026 11:26:32 +0800
Subject: [PATCH 03/14] refactor(logging): colored levels, windowed cache
stats, quieter per-request logs
- Add ANSI color codes to log level names (TTY only; plain in files)
- Introduce SystemStatusReporter with windowed prefix-cache hit rate
alongside the global rate, plus a more compact status line
- Drop gunicorn --access-logfile flags (FastAPI middleware now handles it)
- Remove duplicate _ACCESS_LOG_STATUS_COLORS declaration in api_http.py
- Downgrade noisy per-request / per-batch progress logs from INFO to DEBUG
- Fix flake8 F841 (unused exception variable) in detokenization manager
---
lightllm/common/basemodel/cuda_graph.py | 15 ++-
lightllm/common/triton_utils/autotuner.py | 6 +-
lightllm/server/api_http.py | 1 -
lightllm/server/api_start.py | 6 -
lightllm/server/detokenization/manager.py | 9 +-
lightllm/server/httpserver/manager.py | 8 +-
.../httpserver_for_pd_master/manager.py | 2 +-
lightllm/server/router/batch.py | 8 +-
lightllm/server/router/manager.py | 58 +++++----
lightllm/server/router/stats.py | 123 +++++++++++++++++-
lightllm/utils/log_utils.py | 72 ++++++----
11 files changed, 228 insertions(+), 80 deletions(-)
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 782150661..7a24fd17f 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -4,6 +4,7 @@
import bisect
import triton
from typing import Optional
+from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager
@@ -197,7 +198,11 @@ def warmup(self, model):
model: TpPartBaseModel = model
# decode cuda graph init
- for batch_size in self.cuda_graph_batch_sizes[::-1]:
+ progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
+ for batch_size in progress_bar:
+ avail_mem, _ = torch.cuda.mem_get_info()
+ avail_mem_gb = avail_mem / (1024 ** 3)
+ progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -252,7 +257,13 @@ def warmup_overlap(self, model):
model: TpPartBaseModel = model
- for batch_size in self.cuda_graph_batch_sizes[::-1]:
+ progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
+ for batch_size in progress_bar:
+ avail_mem, _ = torch.cuda.mem_get_info()
+ avail_mem_gb = avail_mem / (1024 ** 3)
+ progress_bar.set_description(
+ f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
+ )
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py
index c62a2572f..3cbc5dc0f 100644
--- a/lightllm/common/triton_utils/autotuner.py
+++ b/lightllm/common/triton_utils/autotuner.py
@@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):
cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
if os.path.exists(cache_file):
- logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
+ logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
return True
@@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
)
- logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
+ logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")
- logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
+ logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")
def _mutate_args_clone(self, args, kwargs):
origin_list = []
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 270e2a8cf..e127f0931 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -116,7 +116,6 @@ def set_args(self, args: StartArgs):
app = FastAPI()
g_objs.app = app
-_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_RESET = "\033[0m"
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 3cf431d65..d191ee394 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -495,8 +495,6 @@ def normal_or_p_d_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
- "--access-logfile",
- "-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
@@ -566,8 +564,6 @@ def pd_master_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
- "--access-logfile",
- "-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
@@ -662,8 +658,6 @@ def config_server_start(args):
f"{args.config_server_host}:{args.config_server_port}",
"--log-level",
"info",
- "--access-logfile",
- "-",
"--error-logfile",
"-",
"lightllm.server.config_server.api_http:app",
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index 8c213914c..ae7103700 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()
- logger.info(
+ logger.debug(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)
@@ -76,7 +76,10 @@ def handle_loop(self):
for _ in range(recv_max_count):
recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
assert isinstance(recv_obj, GroupReqIndexes)
- self._add_new_group_req_index(recv_obj=recv_obj)
+ try:
+ self._add_new_group_req_index(recv_obj=recv_obj)
+ except Exception:
+ logger.exception("add new group req index has exception")
# 当队列中存在较多的请求时,将一次接受的数量上调
recv_max_count = min(int(recv_max_count * 1.3), 256)
@@ -160,7 +163,7 @@ def remove_finished_reqs(self):
for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
- logger.info(f"detoken release req id {decode_req.req.request_id}")
+ logger.debug(f"detoken release req id {decode_req.req.request_id}")
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 0f1b87311..374b089da 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -506,7 +506,7 @@ async def _log_req_header(self, request_headers, group_request_id: int):
x_session_id = request_headers.get("X-Session-Id", "")
format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
- logger.info(
+ logger.debug(
f"received req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
@@ -744,7 +744,7 @@ async def _wait_to_token_package(
(out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1
)
format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
- logger.info(
+ logger.debug(
f"X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_start_time} "
f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms "
@@ -856,8 +856,8 @@ async def recycle_resource_loop(self):
if req_status is None:
continue
- logger.info(
- f"left req id {req_status.group_req_objs.group_req_id}"
+ logger.debug(
+ f"left req id {req_status.group_req_objs.group_req_id} "
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index 104da9f26..3d73dee47 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -197,7 +197,7 @@ async def _log_req_header(self, request: Request, group_request_id: int):
x_request_id = request.headers.get("X-Request-Id", "")
x_session_id = request.headers.get("X-Session-Id", "")
format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
- logger.info(
+ logger.debug(
f"received req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py
index 24d0b9b82..34902d812 100644
--- a/lightllm/server/router/batch.py
+++ b/lightllm/server/router/batch.py
@@ -3,7 +3,6 @@
from typing import Dict, List, Optional, Tuple, Union
from lightllm.server.core.objs import ShmReqManager, Req
from lightllm.utils.log_utils import init_logger
-from .stats import RouterStatics
logger = init_logger(__name__)
@@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]:
all_dp_req_num[req.sample_params.suggested_dp_index] += 1
return all_dp_req_num
- def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics):
+ def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
unfinished_req_ids = []
for req in self.reqs:
if req.shm_infer_released:
- logger.info(f"router release req id {req.request_id}")
- if not req.is_aborted:
- router_statics.update(req.candetoken_out_len)
-
+ logger.debug(f"router release req id {req.request_id}")
shm_req_manager.put_back_req_obj(req)
req = None
else:
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index dfb886660..f88ef53eb 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -16,6 +16,7 @@
from .batch import Batch, Req
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
+from .stats import SystemStatusReporter
from lightllm.server.core.objs.io_objs import (
GroupReqIndexes,
AbortedReqCmd,
@@ -25,7 +26,7 @@
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
-from lightllm.utils.log_utils import init_logger, log_time_ready
+from lightllm.utils.log_utils import init_logger
from lightllm.utils.profiler import ProfilerCmd
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
@@ -68,6 +69,7 @@ def __init__(self, args: StartArgs):
self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager()
# 初始化 radix_cache_client 用于读取 prompt cache 的管理信息
self.radix_cache_client = None
+ self.status_reporter = None
# 共享变量,用于存储router端调度分析得到的机器负载信息
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
@@ -197,6 +199,11 @@ async def wait_to_model_ready(self):
)
self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node)
logger.info(f"use req queue {self.req_queue.__class__.__name__}")
+ self.status_reporter = SystemStatusReporter(
+ args=self.args,
+ max_total_token_num=self.max_total_token_num,
+ dp_size_in_node=self.dp_size_in_node,
+ )
if self.args.run_mode == "prefill":
from lightllm.server.router.model_infer.mode_backend.pd.prefill_node_impl import (
@@ -226,24 +233,9 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
+ # Count output tokens (each running req produces ~1 token per decode step)
+ self.status_reporter.count_output_tokens(len(self.running_batch.reqs))
if counter_count % 100 == 0:
- for dp_index in range(self.dp_size_in_node):
- token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num
- token_ratio2 = (
- self.max_total_token_num
- - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
- ) / self.max_total_token_num
- d_i = dp_index
- estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
- paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i)
- logger.debug(
- f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
- f"dp_i {d_i} paused req num: {paused_req_num} \n"
- f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
- f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
- f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
- )
- logger.debug(self.router_statics.log_str())
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
@@ -265,11 +257,15 @@ async def loop_for_fwd(
self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0)
self.metric_client.gauge_set("lightllm_queue_size", 0.0)
self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0)
- # 60s print once
- if log_time_ready("token_load_info", 60):
- for dp_i in range(self.dp_size_in_node):
- estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i)
- logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n")
+
+ self.status_reporter.maybe_print(
+ running_batch=self.running_batch,
+ req_queue=self.req_queue,
+ read_only_statics_mem_manager=self.read_only_statics_mem_manager,
+ paused_req_num=self._get_paused_req_num(),
+ radix_cache_client=self.radix_cache_client,
+ disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache,
+ )
await asyncio.sleep(self._get_schedule_time_interval())
@@ -300,6 +296,7 @@ async def _step(self):
async def _add_batch(self, batch: Batch):
# 添加新请求
+ self.status_reporter.count_prompt_tokens(batch.input_tokens())
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
while not self.shm_reqs_io_buffer.is_empty():
await asyncio.sleep(0.001)
@@ -344,7 +341,16 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):
def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
- self.running_batch.filter_out_finished_req(self.shm_req_manager, self.router_statics)
+ # Capture finished req stats before filtering
+ for req in self.running_batch.reqs:
+ if req.shm_infer_released:
+ self.status_reporter.on_request_completed(
+ input_len=req.input_len,
+ output_len=req.shm_cur_output_len,
+ cache_len=req.prompt_cache_len,
+ mtp_accepted=req.mtp_accepted_token_num,
+ )
+ self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
return
@@ -416,7 +422,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
req._router_stop_str_matched = False
req_group.append(req)
- logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
+ logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
self.req_queue.extend(req_group)
self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
return
@@ -431,6 +437,8 @@ def _generate_new_batch(self):
logger.info(f"generate new batch, {new_batch.simple_log()}")
self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+ if self.schedule_new_batch is not None:
+ logger.debug(f"gen new batch, {self.schedule_new_batch.simple_log()}")
return
def _multinode_tp_generate_new_batch(self):
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index b715c5bcb..85548d913 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -1,7 +1,126 @@
-from lightllm.utils.log_utils import init_logger
+import time
+import logging
from lightllm.server.core.objs import StartArgs
+from lightllm.utils.log_utils import init_system_status_logger
-logger = init_logger(__name__)
+logger = logging.getLogger(__name__)
+
+
+class SystemStatusReporter:
+ def __init__(self, args, max_total_token_num, dp_size_in_node):
+ self.enabled = not args.disable_log_stats
+ self.interval = max(5, args.log_stats_interval)
+ if args.log_stats_interval < 5:
+ logger.warning(f"log_stats_interval={args.log_stats_interval}s is below minimum, using 5s")
+ self.max_total_token_num = max_total_token_num
+ self.dp_size_in_node = dp_size_in_node
+ self.status_logger = init_system_status_logger("router")
+
+ # Accumulation counters (reset each interval)
+ self.last_print_time = time.time()
+ self.prompt_tokens = 0
+ self.output_tokens = 0
+
+ # Windowed counters for cache hit (reset each interval)
+ self.window_input_total = 0
+ self.window_cache_total = 0
+
+ # Global counters (never reset, for lifetime stats)
+ self.global_input_total = 0
+ self.global_cache_total = 0
+ self.global_mtp_output_total = 0
+ self.global_mtp_accepted_total = 0
+
+ def count_prompt_tokens(self, num_tokens: int):
+ if self.enabled:
+ self.prompt_tokens += num_tokens
+
+ def count_output_tokens(self, num_tokens: int):
+ if self.enabled:
+ self.output_tokens += num_tokens
+
+ def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int):
+ if self.enabled:
+ self.window_input_total += input_len
+ self.window_cache_total += cache_len
+ self.global_input_total += input_len
+ self.global_cache_total += cache_len
+ self.global_mtp_output_total += output_len
+ self.global_mtp_accepted_total += mtp_accepted
+
+ def maybe_print(
+ self,
+ running_batch,
+ req_queue,
+ read_only_statics_mem_manager,
+ paused_req_num=0,
+ radix_cache_client=None,
+ disable_dynamic_prompt_cache=False,
+ ):
+ if not self.enabled:
+ return
+ now = time.time()
+ elapsed = now - self.last_print_time
+ if elapsed < self.interval:
+ return
+
+ total_tps = (self.prompt_tokens + self.output_tokens) / elapsed
+ input_tps = self.prompt_tokens / elapsed
+ output_tps = self.output_tokens / elapsed
+
+ running = len(running_batch.reqs) if running_batch else 0
+ queued = req_queue.get_wait_req_num()
+
+ # Memory utilization (average across dp)
+ # kv_used: physical KV memory usage (includes prefix cache tree occupancy)
+ # kv_used_no_cache: effective usage excluding unrefed prefix cache tokens
+ kv_used_list = []
+ kv_used_no_cache_list = []
+ for dp_i in range(self.dp_size_in_node):
+ unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i)
+ used = self.max_total_token_num - unrefed
+ kv_used_list.append(used / self.max_total_token_num)
+ if not disable_dynamic_prompt_cache and radix_cache_client is not None:
+ cache_unrefed = radix_cache_client.get_unrefed_tokens_num(dp_i)
+ kv_used_no_cache_list.append((used - cache_unrefed) / self.max_total_token_num)
+ else:
+ kv_used_no_cache_list.append(used / self.max_total_token_num)
+ avg_kv_used = sum(kv_used_list) / len(kv_used_list)
+ avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list)
+
+ # Windowed prefix cache hit rate (this interval only)
+ window_cache_hit_rate = (
+ (self.window_cache_total / self.window_input_total * 100) if self.window_input_total > 0 else 0.0
+ )
+ # Global prefix cache hit rate (lifetime)
+ global_cache_hit_rate = (
+ (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0
+ )
+
+ kv_pct = avg_kv_used * 100
+ kv_pct_no_cache = avg_kv_used_no_cache * 100
+
+ # Avg MTP accepted length (only shown when MTP is active)
+ mtp_suffix = ""
+ if self.global_mtp_accepted_total > 0:
+ decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total
+ avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1)
+ mtp_suffix = f", MTP {avg_mtp_len:.2f}"
+
+ self.status_logger.info(
+ f"TPS {total_tps:.1f} (in {input_tps:.1f}, out {output_tps:.1f}), "
+ f"REQ {running}run, {queued}wait, {paused_req_num}pause, "
+ f"KV CACHE {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%), "
+ f"CACHE HIT {window_cache_hit_rate:.1f}% (global {global_cache_hit_rate:.1f}%)"
+ f"{mtp_suffix}"
+ )
+
+ # Reset windowed counters
+ self.prompt_tokens = 0
+ self.output_tokens = 0
+ self.window_input_total = 0
+ self.window_cache_total = 0
+ self.last_print_time = now
class RouterStatics:
diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py
index f15309d5c..c3057d18f 100644
--- a/lightllm/utils/log_utils.py
+++ b/lightllm/utils/log_utils.py
@@ -4,24 +4,41 @@
import logging
import sys
import os
-import time
from typing import Optional
_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
-_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug")
+_STATUS_FORMAT = "%(levelname)s [%(asctime)s] %(message)s"
+
+_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "info")
_LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0)
_LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None)
+# ANSI color codes
+_RESET = "\033[0m"
+_LEVEL_COLORS = {
+ logging.DEBUG: "\033[36m", # cyan
+ logging.INFO: "\033[32m", # green
+ logging.WARNING: "\033[33m", # yellow
+ logging.ERROR: "\033[31m", # red
+ logging.CRITICAL: "\033[1;31m", # bold red
+}
+
class NewLineFormatter(logging.Formatter):
- """Adds logging prefix to newlines to align multi-line messages."""
+ """Adds logging prefix to newlines to align multi-line messages, with optional color on levelname."""
- def __init__(self, fmt, datefmt=None):
+ def __init__(self, fmt, datefmt=None, use_color=False):
logging.Formatter.__init__(self, fmt, datefmt)
+ self.use_color = use_color
def format(self, record):
+ if self.use_color:
+ color = _LEVEL_COLORS.get(record.levelno, "")
+ if color:
+ record = logging.makeLogRecord(record.__dict__)
+ record.levelname = color + record.levelname + _RESET
msg = logging.Formatter.format(self, record)
if record.message != "":
parts = msg.split(record.message)
@@ -39,7 +56,9 @@ def _setup_logger():
_root_logger.setLevel(_LOG_LEVEL)
global _default_handler
global _default_file_handler
- fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
+ _use_color = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
+ color_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=_use_color)
+ plain_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False)
if _default_handler is None:
_default_handler = logging.StreamHandler(sys.stdout)
@@ -55,10 +74,10 @@ def _setup_logger():
_root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}")
_default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log")
_default_file_handler.setLevel(_LOG_LEVEL)
- _default_file_handler.setFormatter(fmt)
+ _default_file_handler.setFormatter(plain_fmt)
_root_logger.addHandler(_default_file_handler)
- _default_handler.setFormatter(fmt)
+ _default_handler.setFormatter(color_fmt)
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False
@@ -89,29 +108,28 @@ def init_logger(name: str):
_root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}")
_inference_log_file_handler[pid] = logging.FileHandler(_LOG_DIR + f"/process.{pid}.log")
_inference_log_file_handler[pid].setLevel(_LOG_LEVEL)
- _inference_log_file_handler[pid].setFormatter(NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT))
+ _inference_log_file_handler[pid].setFormatter(
+ NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False)
+ )
_root_logger.addHandler(_inference_log_file_handler[pid])
logger.addHandler(_inference_log_file_handler[pid])
logger.propagate = False
return logger
-_log_time_mark_dict = {}
-
-
-def log_time_ready(mark_name, time_count: int):
- """
- time_count 间隔时间超过多少s调用该函数会返回True,否则返回False
- 用于控制一些日志输出的频率
- """
- global _log_time_mark_dict
-
- if mark_name not in _log_time_mark_dict:
- _log_time_mark_dict[mark_name] = time.time()
- return False
- cur_time_mark = time.time()
- if cur_time_mark - _log_time_mark_dict[mark_name] >= time_count:
- _log_time_mark_dict[mark_name] = cur_time_mark
- return True
- else:
- return False
+def init_system_status_logger(name: str):
+ logger = logging.getLogger(f"lightllm.status.{name}")
+ if not logger.handlers:
+ logger.setLevel(logging.INFO)
+ fmt = logging.Formatter(_STATUS_FORMAT, datefmt=_DATE_FORMAT)
+ handler = logging.StreamHandler(sys.stdout)
+ handler.flush = sys.stdout.flush
+ handler.setFormatter(fmt)
+ logger.addHandler(handler)
+ if _LOG_DIR is not None:
+ file_handler = logging.FileHandler(os.path.join(_LOG_DIR, f"status.{name}.log"))
+ file_handler.setLevel(logging.INFO)
+ file_handler.setFormatter(fmt)
+ logger.addHandler(file_handler)
+ logger.propagate = False
+ return logger
From c7502e4fe3930857b8682622ec710abe8da27e39 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 6 May 2026 11:26:51 +0800
Subject: [PATCH 04/14] fix(httpserver): reject oversized prompts and translate
ValueError to 400
Reject prompts whose character length exceeds max_req_total_len * 8
before tokenization, so a long string can no longer reach the tokenizer
and stall the loop. The raised ValueError is caught one level up: log it
at WARNING, release any held multimodal resources, abort the in-flight
group request, and re-raise so the API layer (which already maps
ValueError to HTTP 400) returns a graceful error to the client instead
of a 500.
---
lightllm/server/httpserver/manager.py | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 374b089da..945184340 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -317,6 +317,13 @@ async def generate(
# 用于等待 pd_master 下发的交换信息
pd_event: asyncio.Event = None,
) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]:
+ if isinstance(prompt, str):
+ max_prompt_chars = self.max_req_total_len * 8
+ if len(prompt) > max_prompt_chars:
+ raise ValueError(
+ f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
+ f"the request is rejected before tokenization."
+ )
start_time = time.time()
request_headers = request.headers if request is not None else {}
@@ -467,6 +474,12 @@ async def generate(
yield sub_req_id, request_output, metadata, finish_status
+ except ValueError as e:
+ logger.warning(f"group_request_id: {group_request_id} request invalid: {str(e)}")
+ if group_request_id not in self.req_id_to_out_inf:
+ await self._release_multimodal_resources(multimodal_params)
+ await self.abort(group_request_id)
+ raise e
except (ClientDisconnected, Exception) as e:
logger.warning(f"group_request_id: {group_request_id} has exception {str(e)}")
From de49c0e71ce227626ca43bbd32d76bd55bbadd25 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 6 May 2026 13:26:35 +0800
Subject: [PATCH 05/14] fix(router): restore router_statics.update() on req
completion
The earlier refactor that moved the finished-req loop out of
Batch.filter_out_finished_req into Router._filter_reqs_from_running_batch
forgot to keep the router_statics.update(candetoken_out_len) call,
freezing ema_req_out_len at its initial value. Multiple schedulers
(chunked_prefill, beam, pd_decode, nixl_pd) read that EMA for KV-budget
estimation, so leaving it stale degraded scheduling accuracy.
---
lightllm/server/router/manager.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index f88ef53eb..62d82e072 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -350,6 +350,7 @@ def _filter_reqs_from_running_batch(self):
cache_len=req.prompt_cache_len,
mtp_accepted=req.mtp_accepted_token_num,
)
+ self.router_statics.update(req.candetoken_out_len)
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
From eafb4a8cd2177e8d283ea23690f1b6a66c7c033c Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 6 May 2026 15:28:54 +0800
Subject: [PATCH 06/14] fix(router): output TPS via per-req deltas, skip
aborted reqs in stats
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Two correctness fixes flagged in PR review:
1. count_output_tokens(len(running_batch.reqs)) once per router loop is
wrong — the router loop polls on schedule_time_interval, decoupled
from inference, so this overcounts when the loop is faster than
decode and undercounts when slower, and includes paused/prefill-only
reqs. Track shm_cur_output_len per request and accumulate the delta
each tick (with a tail settlement when the req is filtered out so we
don't lose its last tokens to the post-final-tick window).
2. on_request_completed() and router_statics.update() now both run for
aborted requests, whose candetoken_out_len is a short partial value.
Restore the prior `if not req.is_aborted` guard so disconnects don't
bias the output-length EMA used by KV-budget estimators.
---
lightllm/server/router/manager.py | 44 ++++++++++++++++++++++++-------
1 file changed, 34 insertions(+), 10 deletions(-)
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 62d82e072..6dfd9b102 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -70,6 +70,9 @@ def __init__(self, args: StartArgs):
# 初始化 radix_cache_client 用于读取 prompt cache 的管理信息
self.radix_cache_client = None
self.status_reporter = None
+ # Track shm_cur_output_len per running request to compute per-tick deltas
+ # for accurate output TPS regardless of router schedule interval.
+ self._req_last_output_len: Dict[int, int] = {}
# 共享变量,用于存储router端调度分析得到的机器负载信息
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
@@ -233,8 +236,18 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
- # Count output tokens (each running req produces ~1 token per decode step)
- self.status_reporter.count_output_tokens(len(self.running_batch.reqs))
+ # Count output tokens via per-request shm_cur_output_len deltas, since the
+ # router loop runs on schedule_time_interval and len(reqs) is not a per-step
+ # token count.
+ new_output_tokens = 0
+ for req in self.running_batch.reqs:
+ cur_out_len = req.shm_cur_output_len
+ prev_out_len = self._req_last_output_len.get(req.request_id, 0)
+ if cur_out_len > prev_out_len:
+ new_output_tokens += cur_out_len - prev_out_len
+ self._req_last_output_len[req.request_id] = cur_out_len
+ if new_output_tokens:
+ self.status_reporter.count_output_tokens(new_output_tokens)
if counter_count % 100 == 0:
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
@@ -343,14 +356,25 @@ def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
# Capture finished req stats before filtering
for req in self.running_batch.reqs:
- if req.shm_infer_released:
- self.status_reporter.on_request_completed(
- input_len=req.input_len,
- output_len=req.shm_cur_output_len,
- cache_len=req.prompt_cache_len,
- mtp_accepted=req.mtp_accepted_token_num,
- )
- self.router_statics.update(req.candetoken_out_len)
+ if not req.shm_infer_released:
+ continue
+ # Settle any output-token delta produced after the last router tick
+ # so windowed TPS does not lose the request's tail tokens.
+ cur_out_len = req.shm_cur_output_len
+ prev_out_len = self._req_last_output_len.pop(req.request_id, 0)
+ if cur_out_len > prev_out_len:
+ self.status_reporter.count_output_tokens(cur_out_len - prev_out_len)
+ # Aborted/disconnected requests can leave a partial output_len that
+ # would bias the EMA toward shorter generations; skip them.
+ if req.is_aborted:
+ continue
+ self.status_reporter.on_request_completed(
+ input_len=req.input_len,
+ output_len=cur_out_len,
+ cache_len=req.prompt_cache_len,
+ mtp_accepted=req.mtp_accepted_token_num,
+ )
+ self.router_statics.update(req.candetoken_out_len)
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
From f203f4b0e684dc8dd693389ee10fd819e3f2d028 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 6 May 2026 16:05:59 +0800
Subject: [PATCH 07/14] perf(router): sweep output-token deltas once per print
interval
Move the per-running-req shm_cur_output_len delta tracking from the
router tick (~33 Hz) into SystemStatusReporter.maybe_print, which only
runs once per log_stats_interval (>= 5s). The reporter now owns the
per-req snapshot dict and exposes discard_req(req) for tail settlement
when a req leaves the running batch, so the router loop's hot path no
longer walks the batch every schedule cycle. Output TPS accuracy is
unchanged: still based on real shm_cur_output_len deltas, with tail
tokens settled at completion.
---
lightllm/server/router/manager.py | 28 ++++++----------------------
lightllm/server/router/stats.py | 28 +++++++++++++++++++++++++---
2 files changed, 31 insertions(+), 25 deletions(-)
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 6dfd9b102..15e05afb7 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -70,9 +70,6 @@ def __init__(self, args: StartArgs):
# 初始化 radix_cache_client 用于读取 prompt cache 的管理信息
self.radix_cache_client = None
self.status_reporter = None
- # Track shm_cur_output_len per running request to compute per-tick deltas
- # for accurate output TPS regardless of router schedule interval.
- self._req_last_output_len: Dict[int, int] = {}
# 共享变量,用于存储router端调度分析得到的机器负载信息
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
@@ -236,18 +233,8 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
- # Count output tokens via per-request shm_cur_output_len deltas, since the
- # router loop runs on schedule_time_interval and len(reqs) is not a per-step
- # token count.
- new_output_tokens = 0
- for req in self.running_batch.reqs:
- cur_out_len = req.shm_cur_output_len
- prev_out_len = self._req_last_output_len.get(req.request_id, 0)
- if cur_out_len > prev_out_len:
- new_output_tokens += cur_out_len - prev_out_len
- self._req_last_output_len[req.request_id] = cur_out_len
- if new_output_tokens:
- self.status_reporter.count_output_tokens(new_output_tokens)
+ # Output-token counting is done in bulk at the print-window boundary
+ # inside SystemStatusReporter.maybe_print, so the router tick stays cheap.
if counter_count % 100 == 0:
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
@@ -358,19 +345,16 @@ def _filter_reqs_from_running_batch(self):
for req in self.running_batch.reqs:
if not req.shm_infer_released:
continue
- # Settle any output-token delta produced after the last router tick
- # so windowed TPS does not lose the request's tail tokens.
- cur_out_len = req.shm_cur_output_len
- prev_out_len = self._req_last_output_len.pop(req.request_id, 0)
- if cur_out_len > prev_out_len:
- self.status_reporter.count_output_tokens(cur_out_len - prev_out_len)
+ # Settle any output-token tail produced after the last window boundary,
+ # so windowed TPS does not lose the req's last tokens.
+ self.status_reporter.discard_req(req)
# Aborted/disconnected requests can leave a partial output_len that
# would bias the EMA toward shorter generations; skip them.
if req.is_aborted:
continue
self.status_reporter.on_request_completed(
input_len=req.input_len,
- output_len=cur_out_len,
+ output_len=req.shm_cur_output_len,
cache_len=req.prompt_cache_len,
mtp_accepted=req.mtp_accepted_token_num,
)
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index 85548d913..f6db924b5 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -1,5 +1,6 @@
import time
import logging
+from typing import Dict
from lightllm.server.core.objs import StartArgs
from lightllm.utils.log_utils import init_system_status_logger
@@ -31,13 +32,23 @@ def __init__(self, args, max_total_token_num, dp_size_in_node):
self.global_mtp_output_total = 0
self.global_mtp_accepted_total = 0
+ # Per-req shm_cur_output_len snapshot at the previous window boundary,
+ # used to compute the windowed output-token count without per-tick scans.
+ self._req_last_output_len: Dict[int, int] = {}
+
def count_prompt_tokens(self, num_tokens: int):
if self.enabled:
self.prompt_tokens += num_tokens
- def count_output_tokens(self, num_tokens: int):
- if self.enabled:
- self.output_tokens += num_tokens
+ def discard_req(self, req):
+ """Settle a finished/aborted req's tail output tokens (those produced after the last
+ window-boundary sweep) and drop its tracking entry."""
+ if not self.enabled:
+ return
+ cur_out_len = req.shm_cur_output_len
+ prev_out_len = self._req_last_output_len.pop(req.request_id, 0)
+ if cur_out_len > prev_out_len:
+ self.output_tokens += cur_out_len - prev_out_len
def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int):
if self.enabled:
@@ -64,6 +75,17 @@ def maybe_print(
if elapsed < self.interval:
return
+ # Single bulk sweep at the window boundary: account for output tokens produced
+ # by every still-running req since the previous boundary, and refresh their
+ # snapshots. Reqs that finished in this window already settled via discard_req.
+ if running_batch is not None:
+ for req in running_batch.reqs:
+ cur_out_len = req.shm_cur_output_len
+ prev_out_len = self._req_last_output_len.get(req.request_id, 0)
+ if cur_out_len > prev_out_len:
+ self.output_tokens += cur_out_len - prev_out_len
+ self._req_last_output_len[req.request_id] = cur_out_len
+
total_tps = (self.prompt_tokens + self.output_tokens) / elapsed
input_tps = self.prompt_tokens / elapsed
output_tps = self.output_tokens / elapsed
From 064081c2d385eb9cba5cd26d2f454609858654de Mon Sep 17 00:00:00 2001
From: sufubao
Date: Sat, 9 May 2026 11:46:40 +0800
Subject: [PATCH 08/14] fix(httpserver,router): defensive group_request_id
init; reorder is_aborted skip
- httpserver: initialize group_request_id=None so the ValueError except
handler does not hit UnboundLocalError when the oversized-prompt guard
raises before alloc_req_id.
- router: move the is_aborted skip after on_request_completed so aborted
reqs still update completion stats, but do not pollute the router_statics
EMA with their truncated output_len.
---
lightllm/server/httpserver/manager.py | 4 ++++
lightllm/server/router/manager.py | 8 ++++----
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 945184340..6f15f46ab 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -317,7 +317,11 @@ async def generate(
# 用于等待 pd_master 下发的交换信息
pd_event: asyncio.Event = None,
) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]:
+ group_request_id = None
if isinstance(prompt, str):
+ # Guard against extremely long string prompts that might stall the tokenizer
+ # or cause excessive memory usage before tokenization.
+ # 8 characters per token is a conservative heuristic (avg is ~4).
max_prompt_chars = self.max_req_total_len * 8
if len(prompt) > max_prompt_chars:
raise ValueError(
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 15e05afb7..ff08963a6 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -348,16 +348,16 @@ def _filter_reqs_from_running_batch(self):
# Settle any output-token tail produced after the last window boundary,
# so windowed TPS does not lose the req's last tokens.
self.status_reporter.discard_req(req)
- # Aborted/disconnected requests can leave a partial output_len that
- # would bias the EMA toward shorter generations; skip them.
- if req.is_aborted:
- continue
self.status_reporter.on_request_completed(
input_len=req.input_len,
output_len=req.shm_cur_output_len,
cache_len=req.prompt_cache_len,
mtp_accepted=req.mtp_accepted_token_num,
)
+ # Aborted/disconnected requests can leave a partial output_len that
+ # would bias the EMA toward shorter generations; skip them.
+ if req.is_aborted:
+ continue
self.router_statics.update(req.candetoken_out_len)
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
From 0405484e7b001a78c4f54aa34a00c097db3765f9 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Sat, 9 May 2026 13:40:57 +0800
Subject: [PATCH 09/14] fix(detokenization): keep req registration failures
fatal
---
lightllm/server/detokenization/manager.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index ae7103700..90181d13a 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -76,10 +76,7 @@ def handle_loop(self):
for _ in range(recv_max_count):
recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
assert isinstance(recv_obj, GroupReqIndexes)
- try:
- self._add_new_group_req_index(recv_obj=recv_obj)
- except Exception:
- logger.exception("add new group req index has exception")
+ self._add_new_group_req_index(recv_obj=recv_obj)
# 当队列中存在较多的请求时,将一次接受的数量上调
recv_max_count = min(int(recv_max_count * 1.3), 256)
From 42c754b31399bb31d423a2d166a7f3853ace3f3a Mon Sep 17 00:00:00 2001
From: sufubao
Date: Sat, 9 May 2026 13:58:36 +0800
Subject: [PATCH 10/14] refactor(router): improve status log format
---
lightllm/server/router/stats.py | 24 ++++++++++++++----------
1 file changed, 14 insertions(+), 10 deletions(-)
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index f6db924b5..dac37b67b 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -122,20 +122,24 @@ def maybe_print(
kv_pct = avg_kv_used * 100
kv_pct_no_cache = avg_kv_used_no_cache * 100
+ log_parts = [
+ f"router_status(window={elapsed:.1f}s)",
+ f"throughput(total={total_tps:.1f},input={input_tps:.1f},output={output_tps:.1f})",
+ f"req(running={running},waiting={queued},paused={paused_req_num})",
+ f"kv(used={kv_pct_no_cache:.1f}%)",
+ f"gpu_cache_hit(window={window_cache_hit_rate:.1f}%,global={global_cache_hit_rate:.1f}%)",
+ ]
+
# Avg MTP accepted length (only shown when MTP is active)
- mtp_suffix = ""
if self.global_mtp_accepted_total > 0:
decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total
avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1)
- mtp_suffix = f", MTP {avg_mtp_len:.2f}"
-
- self.status_logger.info(
- f"TPS {total_tps:.1f} (in {input_tps:.1f}, out {output_tps:.1f}), "
- f"REQ {running}run, {queued}wait, {paused_req_num}pause, "
- f"KV CACHE {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%), "
- f"CACHE HIT {window_cache_hit_rate:.1f}% (global {global_cache_hit_rate:.1f}%)"
- f"{mtp_suffix}"
- )
+ log_parts.append(
+ f"mtp(avg_tokens_per_step={avg_mtp_len:.2f},"
+ f"accepted={self.global_mtp_accepted_total},output={self.global_mtp_output_total})"
+ )
+
+ self.status_logger.info(" | ".join(log_parts))
# Reset windowed counters
self.prompt_tokens = 0
From 9cd7c4593e49dd6a46b52d034e44a81f346fbee1 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Sat, 9 May 2026 14:09:17 +0800
Subject: [PATCH 11/14] fix(router): remove unused status log variable
---
lightllm/server/router/stats.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index dac37b67b..8aac2b036 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -119,7 +119,6 @@ def maybe_print(
(self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0
)
- kv_pct = avg_kv_used * 100
kv_pct_no_cache = avg_kv_used_no_cache * 100
log_parts = [
From 5a93c5334774f69f05ee87cd1c08b108a254ab4b Mon Sep 17 00:00:00 2001
From: sufubao
Date: Sat, 9 May 2026 14:47:17 +0800
Subject: [PATCH 12/14] feat(router): add debug status details
---
lightllm/server/router/stats.py | 53 +++++++++++++++++++++++++
lightllm/server/visualserver/manager.py | 2 +-
2 files changed, 54 insertions(+), 1 deletion(-)
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index 8aac2b036..c414a4386 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -1,5 +1,6 @@
import time
import logging
+import subprocess
from typing import Dict
from lightllm.server.core.objs import StartArgs
from lightllm.utils.log_utils import init_system_status_logger
@@ -59,6 +60,44 @@ def on_request_completed(self, input_len: int, output_len: int, cache_len: int,
self.global_mtp_output_total += output_len
self.global_mtp_accepted_total += mtp_accepted
+ def _get_gpu_status_for_debug(self) -> str:
+ try:
+ result = subprocess.run(
+ [
+ "nvidia-smi",
+ "--query-gpu=index,utilization.gpu,memory.used,memory.total",
+ "--format=csv,noheader,nounits",
+ ],
+ check=True,
+ capture_output=True,
+ text=True,
+ timeout=2,
+ )
+ except (OSError, subprocess.SubprocessError) as e:
+ return f"gpu=unavailable({e.__class__.__name__})"
+
+ gpu_infos = []
+ for line in result.stdout.splitlines():
+ parts = [part.strip() for part in line.split(",")]
+ if len(parts) != 4:
+ continue
+ gpu_index, util, mem_used, mem_total = parts
+ try:
+ mem_used_mb = float(mem_used)
+ mem_total_mb = float(mem_total)
+ mem_ratio = mem_used_mb / mem_total_mb * 100 if mem_total_mb > 0 else 0.0
+ mem_used_gb = mem_used_mb / 1024
+ mem_total_gb = mem_total_mb / 1024
+ gpu_infos.append(
+ f"{gpu_index}(util={float(util):.0f}%,mem={mem_ratio:.1f}%,"
+ f"used={mem_used_gb:.1f}GiB/{mem_total_gb:.1f}GiB)"
+ )
+ except ValueError:
+ continue
+ if not gpu_infos:
+ return "gpu=unavailable(empty)"
+ return "gpu=[" + ";".join(gpu_infos) + "]"
+
def maybe_print(
self,
running_batch,
@@ -119,6 +158,7 @@ def maybe_print(
(self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0
)
+ kv_pct = avg_kv_used * 100
kv_pct_no_cache = avg_kv_used_no_cache * 100
log_parts = [
@@ -139,6 +179,19 @@ def maybe_print(
)
self.status_logger.info(" | ".join(log_parts))
+ if logger.isEnabledFor(logging.DEBUG):
+ kv_unrefed_prefix_cache_pct = max(0.0, kv_pct - kv_pct_no_cache)
+ debug_parts = [
+ "router_status_debug",
+ f"kv_physical={kv_pct:.1f}%",
+ f"kv_unrefed_prefix_cache={kv_unrefed_prefix_cache_pct:.1f}%",
+ f"throughput_tokens(input={self.prompt_tokens},output={self.output_tokens})",
+ f"gpu_cache_tokens(window={self.window_cache_total}/{self.window_input_total},"
+ f"global={self.global_cache_total}/{self.global_input_total})",
+ f"tracked_output_reqs={len(self._req_last_output_len)}",
+ self._get_gpu_status_for_debug(),
+ ]
+ logger.debug(" | ".join(debug_parts))
# Reset windowed counters
self.prompt_tokens = 0
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 1dffdaf68..ef8637853 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -176,7 +176,7 @@ async def loop_for_netio_req(self):
while True:
recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
if isinstance(recv_req, GroupReqIndexes):
- logger.info(
+ logger.debug(
f"visual recv req id {recv_req.group_req_id} "
f"img count {len(recv_req.multimodal_params.images)}"
)
From 1766c80cc32e3afbbf1a69bae0e5b3595ab8a55c Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 17 Jun 2026 11:05:09 +0800
Subject: [PATCH 13/14] fix(server): restore config access logs and prompt
guard
---
lightllm/server/access_log.py | 31 ++++++++++++++++
lightllm/server/api_http.py | 36 ++-----------------
lightllm/server/config_server/api_http.py | 2 ++
lightllm/server/httpserver/manager.py | 23 +++---------
lightllm/server/httpserver/prompt_utils.py | 10 ++++++
.../httpserver_for_pd_master/manager.py | 2 ++
.../server/config_server/test_access_log.py | 14 ++++++++
.../httpserver/test_prompt_length_guard.py | 33 +++++++++++++++++
8 files changed, 98 insertions(+), 53 deletions(-)
create mode 100644 lightllm/server/access_log.py
create mode 100644 lightllm/server/httpserver/prompt_utils.py
create mode 100644 unit_tests/server/config_server/test_access_log.py
create mode 100644 unit_tests/server/httpserver/test_prompt_length_guard.py
diff --git a/lightllm/server/access_log.py b/lightllm/server/access_log.py
new file mode 100644
index 000000000..7365259c3
--- /dev/null
+++ b/lightllm/server/access_log.py
@@ -0,0 +1,31 @@
+_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
+_ACCESS_LOG_RESET = "\033[0m"
+
+
+class _AccessLogMiddleware:
+ def __init__(self, app, logger):
+ self.app = app
+ self.logger = logger
+
+ async def __call__(self, scope, receive, send):
+ if scope["type"] not in ("http", "websocket"):
+ await self.app(scope, receive, send)
+ return
+
+ status_holder = {"status": 0}
+
+ async def send_wrapper(message):
+ if message["type"] == "http.response.start":
+ status_holder["status"] = message["status"]
+ await send(message)
+
+ try:
+ await self.app(scope, receive, send_wrapper)
+ finally:
+ if scope["type"] == "http":
+ status = status_holder["status"]
+ msg = f"{scope['method']} {scope['path']} {status}"
+ color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "")
+ if color:
+ msg = color + msg + _ACCESS_LOG_RESET
+ self.logger.info(msg)
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index e127f0931..c11778569 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -50,6 +50,7 @@
from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.access_log import _AccessLogMiddleware
from dataclasses import dataclass
from .api_openai import chat_completions_impl, completions_impl
@@ -115,40 +116,7 @@ def set_args(self, args: StartArgs):
app = FastAPI()
g_objs.app = app
-
-_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
-_ACCESS_LOG_RESET = "\033[0m"
-
-
-class _AccessLogMiddleware:
- def __init__(self, app):
- self.app = app
-
- async def __call__(self, scope, receive, send):
- if scope["type"] not in ("http", "websocket"):
- await self.app(scope, receive, send)
- return
-
- status_holder = {"status": 0}
-
- async def send_wrapper(message):
- if message["type"] == "http.response.start":
- status_holder["status"] = message["status"]
- await send(message)
-
- try:
- await self.app(scope, receive, send_wrapper)
- finally:
- if scope["type"] == "http":
- status = status_holder["status"]
- msg = f"{scope['method']} {scope['path']} {status}"
- color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "")
- if color:
- msg = color + msg + _ACCESS_LOG_RESET
- logger.info(msg)
-
-
-app.add_middleware(_AccessLogMiddleware)
+app.add_middleware(_AccessLogMiddleware, logger=logger)
def create_error_response(
diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py
index 3ce39bb6e..8b4e234e0 100644
--- a/lightllm/server/config_server/api_http.py
+++ b/lightllm/server/config_server/api_http.py
@@ -9,6 +9,7 @@
from typing import Dict, List
from fastapi.responses import JSONResponse
from lightllm.utils.log_utils import init_logger
+from lightllm.server.access_log import _AccessLogMiddleware
from lightllm.server.visualserver.objs import VIT_Obj
from ..pd_io_struct import PD_Master_Obj
from .nccl_tcp_store import start_tcp_store_server
@@ -18,6 +19,7 @@
logger = init_logger(__name__)
app = FastAPI()
+app.add_middleware(_AccessLogMiddleware, logger=logger)
registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
registered_visual_server_objs: Dict[str, VIT_Obj] = {}
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 6f15f46ab..fda0ddd40 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -22,6 +22,7 @@
from ..multimodal_params import AudioItem, MultimodalParams, ImageItem
from ..req_id_generator import ReqIDGenerator
from .async_queue import AsyncQueue
+from .prompt_utils import validate_prompt_text_length
from lightllm.server.core.objs import Req, FinishStatus, StartArgs
from lightllm.server.core.objs import SamplingParams
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
@@ -252,6 +253,7 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
kwargs = {} if kwargs is None else kwargs
+ validate_prompt_text_length(prompt, self.max_req_total_len)
prompt_ids = self.tokenizer.encode(prompt, None, **kwargs)
image_tokens = 0
img_count = 0
@@ -318,16 +320,7 @@ async def generate(
pd_event: asyncio.Event = None,
) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]:
group_request_id = None
- if isinstance(prompt, str):
- # Guard against extremely long string prompts that might stall the tokenizer
- # or cause excessive memory usage before tokenization.
- # 8 characters per token is a conservative heuristic (avg is ~4).
- max_prompt_chars = self.max_req_total_len * 8
- if len(prompt) > max_prompt_chars:
- raise ValueError(
- f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
- f"the request is rejected before tokenization."
- )
+ validate_prompt_text_length(prompt, self.max_req_total_len)
start_time = time.time()
request_headers = request.headers if request is not None else {}
@@ -534,15 +527,7 @@ async def _encode(
self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams
):
if isinstance(prompt, str):
- # pre-verify prompt length
- # The average character length per token is always less than 8
- # TODO: automatically calculate the average character length per token
- max_prompt_chars = self.max_req_total_len * 8
- if len(prompt) > max_prompt_chars:
- raise ValueError(
- f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
- f"the request is rejected before tokenization."
- )
+ validate_prompt_text_length(prompt, self.max_req_total_len)
if self.enable_multimodal:
assert (
len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
diff --git a/lightllm/server/httpserver/prompt_utils.py b/lightllm/server/httpserver/prompt_utils.py
new file mode 100644
index 000000000..98f198c23
--- /dev/null
+++ b/lightllm/server/httpserver/prompt_utils.py
@@ -0,0 +1,10 @@
+def validate_prompt_text_length(prompt, max_req_total_len):
+ if not isinstance(prompt, str):
+ return
+
+ max_prompt_chars = max_req_total_len * 8
+ if len(prompt) > max_prompt_chars:
+ raise ValueError(
+ f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
+ f"the request is rejected before tokenization."
+ )
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index 3d73dee47..796a7a7ed 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -19,6 +19,7 @@
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
+from lightllm.server.httpserver.prompt_utils import validate_prompt_text_length
from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError
from lightllm.utils.envs_utils import get_pd_split_max_new_tokens
from .pd_selector import create_selector
@@ -73,6 +74,7 @@ async def update_req_status(self, upkv_status: PDUpKVStatus):
def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
kwargs = {} if kwargs is None else kwargs
+ validate_prompt_text_length(prompt, self.max_req_total_len)
prompt_ids = self.tokenizer.encode(prompt, None, **kwargs)
image_tokens = 0
img_count = 0
diff --git a/unit_tests/server/config_server/test_access_log.py b/unit_tests/server/config_server/test_access_log.py
new file mode 100644
index 000000000..497823980
--- /dev/null
+++ b/unit_tests/server/config_server/test_access_log.py
@@ -0,0 +1,14 @@
+from fastapi.testclient import TestClient
+
+
+def test_config_server_emits_access_log(monkeypatch):
+ from lightllm.server.config_server import api_http
+
+ messages = []
+ monkeypatch.setattr(api_http.logger, "info", lambda msg, *args, **kwargs: messages.append(str(msg)))
+
+ with TestClient(api_http.app) as client:
+ response = client.get("/health")
+
+ assert response.status_code == 200
+ assert any("GET /health 200" in message for message in messages)
diff --git a/unit_tests/server/httpserver/test_prompt_length_guard.py b/unit_tests/server/httpserver/test_prompt_length_guard.py
new file mode 100644
index 000000000..f03a59a32
--- /dev/null
+++ b/unit_tests/server/httpserver/test_prompt_length_guard.py
@@ -0,0 +1,33 @@
+from types import SimpleNamespace
+
+import pytest
+
+from lightllm.server.httpserver.manager import HttpServerManager
+from lightllm.server.httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
+
+
+class TokenizerMustNotRun:
+ vocab_size = 32000
+
+ def encode(self, *args, **kwargs):
+ raise AssertionError("tokenizer should not run for oversized text prompts")
+
+
+def _fake_manager():
+ return SimpleNamespace(
+ max_req_total_len=4,
+ tokenizer=TokenizerMustNotRun(),
+ args=SimpleNamespace(max_image_token_count=1024),
+ )
+
+
+def _empty_multimodal_params():
+ return SimpleNamespace(images=[], audios=[])
+
+
+@pytest.mark.parametrize("manager_cls", [HttpServerManager, HttpServerManagerForPDMaster])
+def test_tokens_rejects_oversized_text_prompt_before_tokenization(manager_cls):
+ prompt = "x" * 33
+
+ with pytest.raises(ValueError, match="prompt text length 33 exceeds the character limit 32"):
+ manager_cls.tokens(_fake_manager(), prompt, _empty_multimodal_params(), SimpleNamespace())
From 4360f5ccc85afe81023a8b242fff172bf4602a65 Mon Sep 17 00:00:00 2001
From: sufubao
Date: Wed, 17 Jun 2026 22:32:01 +0800
Subject: [PATCH 14/14] chore(router): trim redundant logging comments
---
lightllm/server/router/manager.py | 5 -----
lightllm/server/router/stats.py | 8 --------
lightllm/utils/log_utils.py | 1 -
3 files changed, 14 deletions(-)
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index ff08963a6..e9d90e864 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -233,8 +233,6 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
- # Output-token counting is done in bulk at the print-window boundary
- # inside SystemStatusReporter.maybe_print, so the router tick stays cheap.
if counter_count % 100 == 0:
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
@@ -341,12 +339,9 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):
def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
- # Capture finished req stats before filtering
for req in self.running_batch.reqs:
if not req.shm_infer_released:
continue
- # Settle any output-token tail produced after the last window boundary,
- # so windowed TPS does not lose the req's last tokens.
self.status_reporter.discard_req(req)
self.status_reporter.on_request_completed(
input_len=req.input_len,
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index c414a4386..94556f21c 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -18,16 +18,13 @@ def __init__(self, args, max_total_token_num, dp_size_in_node):
self.dp_size_in_node = dp_size_in_node
self.status_logger = init_system_status_logger("router")
- # Accumulation counters (reset each interval)
self.last_print_time = time.time()
self.prompt_tokens = 0
self.output_tokens = 0
- # Windowed counters for cache hit (reset each interval)
self.window_input_total = 0
self.window_cache_total = 0
- # Global counters (never reset, for lifetime stats)
self.global_input_total = 0
self.global_cache_total = 0
self.global_mtp_output_total = 0
@@ -132,7 +129,6 @@ def maybe_print(
running = len(running_batch.reqs) if running_batch else 0
queued = req_queue.get_wait_req_num()
- # Memory utilization (average across dp)
# kv_used: physical KV memory usage (includes prefix cache tree occupancy)
# kv_used_no_cache: effective usage excluding unrefed prefix cache tokens
kv_used_list = []
@@ -149,11 +145,9 @@ def maybe_print(
avg_kv_used = sum(kv_used_list) / len(kv_used_list)
avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list)
- # Windowed prefix cache hit rate (this interval only)
window_cache_hit_rate = (
(self.window_cache_total / self.window_input_total * 100) if self.window_input_total > 0 else 0.0
)
- # Global prefix cache hit rate (lifetime)
global_cache_hit_rate = (
(self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0
)
@@ -169,7 +163,6 @@ def maybe_print(
f"gpu_cache_hit(window={window_cache_hit_rate:.1f}%,global={global_cache_hit_rate:.1f}%)",
]
- # Avg MTP accepted length (only shown when MTP is active)
if self.global_mtp_accepted_total > 0:
decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total
avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1)
@@ -193,7 +186,6 @@ def maybe_print(
]
logger.debug(" | ".join(debug_parts))
- # Reset windowed counters
self.prompt_tokens = 0
self.output_tokens = 0
self.window_input_total = 0
diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py
index c3057d18f..6df36c3ae 100644
--- a/lightllm/utils/log_utils.py
+++ b/lightllm/utils/log_utils.py
@@ -15,7 +15,6 @@
_LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0)
_LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None)
-# ANSI color codes
_RESET = "\033[0m"
_LEVEL_COLORS = {
logging.DEBUG: "\033[36m", # cyan