diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..7a24fd17fb 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -4,6 +4,7 @@ import bisect import triton from typing import Optional +from tqdm import tqdm from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager @@ -197,7 +198,11 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs") + for batch_size in progress_bar: + avail_mem, _ = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB") seq_len = 2 total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch @@ -252,7 +257,13 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs") + for batch_size in progress_bar: + avail_mem, _ = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description( + f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB" + ) decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..3cbc5dc0f3 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -215,7 +215,7 @@ def _try_load_cache(self, static_key): cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key)) if os.path.exists(cache_file): - logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}") + logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}") with open(cache_file, "rb") as f: self.cached_configs[static_key] = orjson.loads(f.read()) return True @@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size): option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS, ) ) - logger.info(f"Saved configs for {self.kernel_name} - {_static_key}") + logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}") - logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished") + logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished") def _mutate_args_clone(self, args, kwargs): origin_list = [] diff --git a/lightllm/server/access_log.py b/lightllm/server/access_log.py new file mode 100644 index 0000000000..7365259c37 --- /dev/null +++ b/lightllm/server/access_log.py @@ -0,0 +1,31 @@ +_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} +_ACCESS_LOG_RESET = "\033[0m" + + +class _AccessLogMiddleware: + def __init__(self, app, logger): + self.app = app + self.logger = logger + + async def __call__(self, scope, receive, send): + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + status_holder = {"status": 0} + + async def send_wrapper(message): + if message["type"] == "http.response.start": + status_holder["status"] = message["status"] + await send(message) + + try: + await self.app(scope, receive, send_wrapper) + finally: + if scope["type"] == "http": + status = status_holder["status"] + msg = f"{scope['method']} {scope['path']} {status}" + color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "") + if color: + msg = color + msg + _ACCESS_LOG_RESET + self.logger.info(msg) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 270e2a8cfd..c117785698 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -50,6 +50,7 @@ from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.access_log import _AccessLogMiddleware from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -115,41 +116,7 @@ def set_args(self, args: StartArgs): app = FastAPI() g_objs.app = app - -_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} -_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} -_ACCESS_LOG_RESET = "\033[0m" - - -class _AccessLogMiddleware: - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - if scope["type"] not in ("http", "websocket"): - await self.app(scope, receive, send) - return - - status_holder = {"status": 0} - - async def send_wrapper(message): - if message["type"] == "http.response.start": - status_holder["status"] = message["status"] - await send(message) - - try: - await self.app(scope, receive, send_wrapper) - finally: - if scope["type"] == "http": - status = status_holder["status"] - msg = f"{scope['method']} {scope['path']} {status}" - color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "") - if color: - msg = color + msg + _ACCESS_LOG_RESET - logger.info(msg) - - -app.add_middleware(_AccessLogMiddleware) +app.add_middleware(_AccessLogMiddleware, logger=logger) def create_error_response( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3cf431d650..d191ee3940 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -495,8 +495,6 @@ def normal_or_p_d_start(args): f"{args.host}:{args.port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.api_http:app", @@ -566,8 +564,6 @@ def pd_master_start(args): f"{args.host}:{args.port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.api_http:app", @@ -662,8 +658,6 @@ def config_server_start(args): f"{args.config_server_host}:{args.config_server_port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.config_server.api_http:app", diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index 3ce39bb6e6..8b4e234e02 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.access_log import _AccessLogMiddleware from lightllm.server.visualserver.objs import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server @@ -18,6 +19,7 @@ logger = init_logger(__name__) app = FastAPI() +app.add_middleware(_AccessLogMiddleware, logger=logger) registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} registered_visual_server_objs: Dict[str, VIT_Obj] = {} diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 8c213914c7..90181d13af 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes): req.link_prompt_ids_shm_array() req.link_logprobs_shm_array() - logger.info( + logger.debug( f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s" ) @@ -160,7 +160,7 @@ def remove_finished_reqs(self): for decode_req in finished_reqs: decode_req.req.can_released_mark = True - logger.info(f"detoken release req id {decode_req.req.request_id}") + logger.debug(f"detoken release req id {decode_req.req.request_id}") self.shm_req_manager.put_back_req_obj(decode_req.req) self.req_id_to_out.pop(decode_req.request_id, None) return diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index dfcb2f8d9e..093c0c5681 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1758,6 +1758,10 @@ def __init__(self): self.parameter_regex = re.compile( r"|(?=)|$)", re.DOTALL ) + self.parameter_stream_regex = re.compile( + r"]*)>(.*?)(|(?=)|$)", re.DOTALL + ) + self.function_start_regex = re.compile(r"]*)>", re.DOTALL) self._normal_text_buffer = "" def has_tool_call(self, text: str) -> bool: @@ -1782,8 +1786,7 @@ def _convert_param_value(self, value: str, param_name: str, param_config: Dict, if param_name not in param_config: return value - prop = param_config.get(param_name, {}) - param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + param_type = self._get_qwen3_param_type(param_name, param_config) if param_type in ("string", "str", "enum"): return value @@ -1833,12 +1836,7 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional except ValueError: continue param_name = match[:idx].strip() - param_value = match[idx + 1 :] - # Strip leading/trailing newlines from value - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] + param_value = self._strip_value_newlines(match[idx + 1 :]) param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) @@ -1848,47 +1846,120 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional parameters=json.dumps(param_dict, ensure_ascii=False), ) - def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools: List[Tool]) -> Optional[str]: - """Build the current argument JSON from a partial XML tool-call body.""" - param_matches = self.parameter_regex.findall(partial_body) - if not param_matches: - return None + def _get_qwen3_param_type(self, param_name: str, param_config: Dict) -> str: + prop = param_config.get(param_name, {}) + return str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + + def _strip_value_newlines(self, value: str) -> str: + """Strip the single leading/trailing newline the Qwen3 template wraps each value in.""" + if value.startswith("\n"): + value = value[1:] + if value.endswith("\n"): + value = value[:-1] + return value + def _strip_partial_xml_suffix(self, value: str) -> str: + for token in ("", "", self.eot_token): + max_len = min(len(value), len(token) - 1) + for suffix_len in range(max_len, 0, -1): + if token.startswith(value[-suffix_len:]): + return value[:-suffix_len] + return value + + def _build_streaming_arguments_json( + self, + func_name: str, + partial_body: str, + tools: List[Tool], + close_object: bool = False, + ) -> Optional[str]: + """Build a monotonic JSON arguments prefix for XML tool-call streaming. + + The result is always a byte-exact prefix of json.dumps(final_arguments) so the + serving layer (api_openai.py) can reconcile the streamed args at stream end. + String values stream character-by-character (a string prefix stays a prefix); + non-string values are only emitted once their arrives, because a + partial number/array/bool is not guaranteed to be a prefix of its json.dumps form. + """ param_config = self._get_param_config(func_name, tools) - param_dict = {} - has_visible_value = False + parts = ["{"] + has_param = False - for match in param_matches: - try: - idx = match.index(">") - except ValueError: + for match in self.parameter_stream_regex.finditer(partial_body): + param_name = match.group(1).strip() + if not param_name: continue - param_name = match[:idx].strip() - param_value = match[idx + 1 :] - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - if param_value.strip(): - has_visible_value = True - elif ( - f"" in partial_body - and f"{param_value}" in partial_body - ): - # Closed empty-string parameter. We can safely emit it. - has_visible_value = True + # The value is complete only when an explicit closed it, or a + # sibling follows. Otherwise it is still streaming. + # (We can't key off match.end()==len: `$` matches before a trailing newline, + # and the template wraps every value in one, which would look "complete".) + rest = partial_body[match.end() :] + value_open = ( + match.group(3) != "" + and not rest.startswith("") + ) + + if has_param: + parts.append(", ") + parts.append(json.dumps(param_name, ensure_ascii=False)) + parts.append(": ") + has_param = True + + param_type = self._get_qwen3_param_type(param_name, param_config) + is_string = param_type in ("string", "str", "enum") + + if value_open: + # In-progress (and therefore last) parameter. + if is_string: + value = self._strip_value_newlines(self._strip_partial_xml_suffix(match.group(2))) + # Drop the closing quote so the stream stays an extendable prefix. + parts.append(json.dumps(value, ensure_ascii=False)[:-1]) + # Non-string values cannot be emitted as a safe partial prefix, so stop + # after the key and wait for the value to close. + return "".join(parts) + + value = self._strip_value_newlines(match.group(2)) + if is_string: + parts.append(json.dumps(value, ensure_ascii=False)) else: - # Parameter tag is present but its value has not started streaming yet. - continue + converted = self._convert_param_value(value, param_name, param_config, func_name) + parts.append(json.dumps(converted, ensure_ascii=False)) - param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) + if not has_param: + return "{}" if close_object else None - if not param_dict and not has_visible_value: - return None + if close_object: + parts.append("}") - return json.dumps(param_dict, ensure_ascii=False) + return "".join(parts) + + def _ensure_qwen3_stream_state(self, tool_index: int) -> None: + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + def _append_qwen3_arguments_delta(self, calls: List[ToolCallItem], tool_index: int, current_args_json: str) -> None: + sent_args = self.streamed_args_for_tool[tool_index] + if not current_args_json.startswith(sent_args): + logger.warning( + "Qwen3-Coder streaming arguments are not monotonic for tool index %s; skip delta.", + tool_index, + ) + return + + argument_diff = current_args_json[len(sent_args) :] + if argument_diff: + calls.append( + ToolCallItem( + tool_index=tool_index, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[tool_index] += argument_diff def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: idx = text.find(self.bot_token) @@ -1941,23 +2012,100 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami self._buffer = current_text[tool_call_start:] current_text = self._buffer + if self.current_tool_id == -1: + self.current_tool_id = 0 + + self._ensure_qwen3_stream_state(self.current_tool_id) + + function_match = self.function_start_regex.search(current_text) + if not function_match: + return StreamingParseResult(normal_text=normal_text, calls=calls) + + func_name = function_match.group(1).strip() eot_pos = current_text.find(self.eot_token) + func_defined = func_name in self._tool_indices + + # Undefined function whose block has not finished yet: wait for more text + # (the block may also contain a valid function we shouldn't drop). + if not func_defined and eot_pos == -1: + return StreamingParseResult(normal_text=normal_text, calls=calls) + + if func_defined: + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + + # The function body is complete once we hit either or the + # enclosing ; treating eot as an implicit close lets us emit + # the closing '}' inside the same args delta (so the serving layer's + # stop-time reconciliation sees one delta, not a separate trailing '}'). + function_close_pos = current_text.find("", function_match.end()) + if function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos): + partial_end = function_close_pos + close_object = True + elif eot_pos != -1: + partial_end = eot_pos + close_object = True + else: + partial_end = len(current_text) + close_object = False + partial_body = current_text[function_match.end() : partial_end] + current_args_json = self._build_streaming_arguments_json( + func_name, + partial_body, + tools, + close_object=close_object, + ) + if current_args_json: + self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json) + if eot_pos == -1: return StreamingParseResult(normal_text=normal_text, calls=calls) complete_block = current_text[: eot_pos + len(self.eot_token)] func_matches = self.function_regex.findall(complete_block) - if self.current_tool_id == -1: - self.current_tool_id = 0 - + # Flush every completed function in the block. _parse_function_call returns + # None for undefined ones, so they are skipped without advancing the index. for match in func_matches: func_str = match[0] if match[0] else match[1] item = self._parse_function_call(func_str, tools) - if item: - item.tool_index = self.current_tool_id - calls.append(item) - self.current_tool_id += 1 + if not item: + continue + completed_tool_id = self.current_tool_id + self._ensure_qwen3_stream_state(completed_tool_id) + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=completed_tool_id, + name=item.name, + parameters="", + ) + ) + self.current_tool_name_sent = True + try: + parsed_args = json.loads(item.parameters) + except json.JSONDecodeError: + parsed_args = {} + self.prev_tool_call_arr[completed_tool_id] = { + "name": item.name, + "arguments": parsed_args, + } + sent_args = self.streamed_args_for_tool[completed_tool_id] + if item.parameters.startswith(sent_args): + self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters) + self.current_tool_id += 1 + self.current_tool_name_sent = False self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0f1b873111..fda0ddd407 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -22,6 +22,7 @@ from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator from .async_queue import AsyncQueue +from .prompt_utils import validate_prompt_text_length from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE @@ -252,6 +253,7 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): kwargs = {} if kwargs is None else kwargs + validate_prompt_text_length(prompt, self.max_req_total_len) prompt_ids = self.tokenizer.encode(prompt, None, **kwargs) image_tokens = 0 img_count = 0 @@ -317,6 +319,8 @@ async def generate( # 用于等待 pd_master 下发的交换信息 pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: + group_request_id = None + validate_prompt_text_length(prompt, self.max_req_total_len) start_time = time.time() request_headers = request.headers if request is not None else {} @@ -467,6 +471,12 @@ async def generate( yield sub_req_id, request_output, metadata, finish_status + except ValueError as e: + logger.warning(f"group_request_id: {group_request_id} request invalid: {str(e)}") + if group_request_id not in self.req_id_to_out_inf: + await self._release_multimodal_resources(multimodal_params) + await self.abort(group_request_id) + raise e except (ClientDisconnected, Exception) as e: logger.warning(f"group_request_id: {group_request_id} has exception {str(e)}") @@ -506,7 +516,7 @@ async def _log_req_header(self, request_headers, group_request_id: int): x_session_id = request_headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"received req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " @@ -517,15 +527,7 @@ async def _encode( self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams ): if isinstance(prompt, str): - # pre-verify prompt length - # The average character length per token is always less than 8 - # TODO: automatically calculate the average character length per token - max_prompt_chars = self.max_req_total_len * 8 - if len(prompt) > max_prompt_chars: - raise ValueError( - f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " - f"the request is rejected before tokenization." - ) + validate_prompt_text_length(prompt, self.max_req_total_len) if self.enable_multimodal: assert ( len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity @@ -744,7 +746,7 @@ async def _wait_to_token_package( (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_start_time} " f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " @@ -856,8 +858,8 @@ async def recycle_resource_loop(self): if req_status is None: continue - logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + logger.debug( + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/httpserver/prompt_utils.py b/lightllm/server/httpserver/prompt_utils.py new file mode 100644 index 0000000000..98f198c23c --- /dev/null +++ b/lightllm/server/httpserver/prompt_utils.py @@ -0,0 +1,10 @@ +def validate_prompt_text_length(prompt, max_req_total_len): + if not isinstance(prompt, str): + return + + max_prompt_chars = max_req_total_len * 8 + if len(prompt) > max_prompt_chars: + raise ValueError( + f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " + f"the request is rejected before tokenization." + ) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 104da9f26e..796a7a7ed1 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -19,6 +19,7 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage from lightllm.server.httpserver.manager import AsyncQueue +from lightllm.server.httpserver.prompt_utils import validate_prompt_text_length from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError from lightllm.utils.envs_utils import get_pd_split_max_new_tokens from .pd_selector import create_selector @@ -73,6 +74,7 @@ async def update_req_status(self, upkv_status: PDUpKVStatus): def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): kwargs = {} if kwargs is None else kwargs + validate_prompt_text_length(prompt, self.max_req_total_len) prompt_ids = self.tokenizer.encode(prompt, None, **kwargs) image_tokens = 0 img_count = 0 @@ -197,7 +199,7 @@ async def _log_req_header(self, request: Request, group_request_id: int): x_request_id = request.headers.get("X-Request-Id", "") x_session_id = request.headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"received req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 24d0b9b824..34902d812a 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union from lightllm.server.core.objs import ShmReqManager, Req from lightllm.utils.log_utils import init_logger -from .stats import RouterStatics logger = init_logger(__name__) @@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]: all_dp_req_num[req.sample_params.suggested_dp_index] += 1 return all_dp_req_num - def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics): + def filter_out_finished_req(self, shm_req_manager: ShmReqManager): unfinished_req_ids = [] for req in self.reqs: if req.shm_infer_released: - logger.info(f"router release req id {req.request_id}") - if not req.is_aborted: - router_statics.update(req.candetoken_out_len) - + logger.debug(f"router release req id {req.request_id}") shm_req_manager.put_back_req_obj(req) req = None else: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index dfb8866601..e9d90e8643 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -16,6 +16,7 @@ from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue +from .stats import SystemStatusReporter from lightllm.server.core.objs.io_objs import ( GroupReqIndexes, AbortedReqCmd, @@ -25,7 +26,7 @@ from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer -from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.log_utils import init_logger from lightllm.utils.profiler import ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient @@ -68,6 +69,7 @@ def __init__(self, args: StartArgs): self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None + self.status_reporter = None # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -197,6 +199,11 @@ async def wait_to_model_ready(self): ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") + self.status_reporter = SystemStatusReporter( + args=self.args, + max_total_token_num=self.max_total_token_num, + dp_size_in_node=self.dp_size_in_node, + ) if self.args.run_mode == "prefill": from lightllm.server.router.model_infer.mode_backend.pd.prefill_node_impl import ( @@ -227,23 +234,6 @@ async def loop_for_fwd( counter_count += 1 if self.running_batch is not None: if counter_count % 100 == 0: - for dp_index in range(self.dp_size_in_node): - token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num - token_ratio2 = ( - self.max_total_token_num - - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - ) / self.max_total_token_num - d_i = dp_index - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) - paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) - logger.debug( - f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" - f"dp_i {d_i} paused req num: {paused_req_num} \n" - f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" - f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" - f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" - ) - logger.debug(self.router_statics.log_str()) self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode) @@ -265,11 +255,15 @@ async def loop_for_fwd( self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) - # 60s print once - if log_time_ready("token_load_info", 60): - for dp_i in range(self.dp_size_in_node): - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) - logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n") + + self.status_reporter.maybe_print( + running_batch=self.running_batch, + req_queue=self.req_queue, + read_only_statics_mem_manager=self.read_only_statics_mem_manager, + paused_req_num=self._get_paused_req_num(), + radix_cache_client=self.radix_cache_client, + disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache, + ) await asyncio.sleep(self._get_schedule_time_interval()) @@ -300,6 +294,7 @@ async def _step(self): async def _add_batch(self, batch: Batch): # 添加新请求 + self.status_reporter.count_prompt_tokens(batch.input_tokens()) reqs = [r.to_router_rpc_obj() for r in batch.reqs] while not self.shm_reqs_io_buffer.is_empty(): await asyncio.sleep(0.001) @@ -344,7 +339,22 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch): def _filter_reqs_from_running_batch(self): if self.running_batch is not None: - self.running_batch.filter_out_finished_req(self.shm_req_manager, self.router_statics) + for req in self.running_batch.reqs: + if not req.shm_infer_released: + continue + self.status_reporter.discard_req(req) + self.status_reporter.on_request_completed( + input_len=req.input_len, + output_len=req.shm_cur_output_len, + cache_len=req.prompt_cache_len, + mtp_accepted=req.mtp_accepted_token_num, + ) + # Aborted/disconnected requests can leave a partial output_len that + # would bias the EMA toward shorter generations; skip them. + if req.is_aborted: + continue + self.router_statics.update(req.candetoken_out_len) + self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None return @@ -416,7 +426,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes): req._router_stop_str_matched = False req_group.append(req) - logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") + logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s") self.req_queue.extend(req_group) self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) return @@ -431,6 +441,8 @@ def _generate_new_batch(self): logger.info(f"generate new batch, {new_batch.simple_log()}") self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch) + if self.schedule_new_batch is not None: + logger.debug(f"gen new batch, {self.schedule_new_batch.simple_log()}") return def _multinode_tp_generate_new_batch(self): diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index b715c5bcb3..94556f21cb 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,7 +1,196 @@ -from lightllm.utils.log_utils import init_logger +import time +import logging +import subprocess +from typing import Dict from lightllm.server.core.objs import StartArgs +from lightllm.utils.log_utils import init_system_status_logger -logger = init_logger(__name__) +logger = logging.getLogger(__name__) + + +class SystemStatusReporter: + def __init__(self, args, max_total_token_num, dp_size_in_node): + self.enabled = not args.disable_log_stats + self.interval = max(5, args.log_stats_interval) + if args.log_stats_interval < 5: + logger.warning(f"log_stats_interval={args.log_stats_interval}s is below minimum, using 5s") + self.max_total_token_num = max_total_token_num + self.dp_size_in_node = dp_size_in_node + self.status_logger = init_system_status_logger("router") + + self.last_print_time = time.time() + self.prompt_tokens = 0 + self.output_tokens = 0 + + self.window_input_total = 0 + self.window_cache_total = 0 + + self.global_input_total = 0 + self.global_cache_total = 0 + self.global_mtp_output_total = 0 + self.global_mtp_accepted_total = 0 + + # Per-req shm_cur_output_len snapshot at the previous window boundary, + # used to compute the windowed output-token count without per-tick scans. + self._req_last_output_len: Dict[int, int] = {} + + def count_prompt_tokens(self, num_tokens: int): + if self.enabled: + self.prompt_tokens += num_tokens + + def discard_req(self, req): + """Settle a finished/aborted req's tail output tokens (those produced after the last + window-boundary sweep) and drop its tracking entry.""" + if not self.enabled: + return + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.pop(req.request_id, 0) + if cur_out_len > prev_out_len: + self.output_tokens += cur_out_len - prev_out_len + + def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int): + if self.enabled: + self.window_input_total += input_len + self.window_cache_total += cache_len + self.global_input_total += input_len + self.global_cache_total += cache_len + self.global_mtp_output_total += output_len + self.global_mtp_accepted_total += mtp_accepted + + def _get_gpu_status_for_debug(self) -> str: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=index,utilization.gpu,memory.used,memory.total", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + timeout=2, + ) + except (OSError, subprocess.SubprocessError) as e: + return f"gpu=unavailable({e.__class__.__name__})" + + gpu_infos = [] + for line in result.stdout.splitlines(): + parts = [part.strip() for part in line.split(",")] + if len(parts) != 4: + continue + gpu_index, util, mem_used, mem_total = parts + try: + mem_used_mb = float(mem_used) + mem_total_mb = float(mem_total) + mem_ratio = mem_used_mb / mem_total_mb * 100 if mem_total_mb > 0 else 0.0 + mem_used_gb = mem_used_mb / 1024 + mem_total_gb = mem_total_mb / 1024 + gpu_infos.append( + f"{gpu_index}(util={float(util):.0f}%,mem={mem_ratio:.1f}%," + f"used={mem_used_gb:.1f}GiB/{mem_total_gb:.1f}GiB)" + ) + except ValueError: + continue + if not gpu_infos: + return "gpu=unavailable(empty)" + return "gpu=[" + ";".join(gpu_infos) + "]" + + def maybe_print( + self, + running_batch, + req_queue, + read_only_statics_mem_manager, + paused_req_num=0, + radix_cache_client=None, + disable_dynamic_prompt_cache=False, + ): + if not self.enabled: + return + now = time.time() + elapsed = now - self.last_print_time + if elapsed < self.interval: + return + + # Single bulk sweep at the window boundary: account for output tokens produced + # by every still-running req since the previous boundary, and refresh their + # snapshots. Reqs that finished in this window already settled via discard_req. + if running_batch is not None: + for req in running_batch.reqs: + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.get(req.request_id, 0) + if cur_out_len > prev_out_len: + self.output_tokens += cur_out_len - prev_out_len + self._req_last_output_len[req.request_id] = cur_out_len + + total_tps = (self.prompt_tokens + self.output_tokens) / elapsed + input_tps = self.prompt_tokens / elapsed + output_tps = self.output_tokens / elapsed + + running = len(running_batch.reqs) if running_batch else 0 + queued = req_queue.get_wait_req_num() + + # kv_used: physical KV memory usage (includes prefix cache tree occupancy) + # kv_used_no_cache: effective usage excluding unrefed prefix cache tokens + kv_used_list = [] + kv_used_no_cache_list = [] + for dp_i in range(self.dp_size_in_node): + unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i) + used = self.max_total_token_num - unrefed + kv_used_list.append(used / self.max_total_token_num) + if not disable_dynamic_prompt_cache and radix_cache_client is not None: + cache_unrefed = radix_cache_client.get_unrefed_tokens_num(dp_i) + kv_used_no_cache_list.append((used - cache_unrefed) / self.max_total_token_num) + else: + kv_used_no_cache_list.append(used / self.max_total_token_num) + avg_kv_used = sum(kv_used_list) / len(kv_used_list) + avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list) + + window_cache_hit_rate = ( + (self.window_cache_total / self.window_input_total * 100) if self.window_input_total > 0 else 0.0 + ) + global_cache_hit_rate = ( + (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 + ) + + kv_pct = avg_kv_used * 100 + kv_pct_no_cache = avg_kv_used_no_cache * 100 + + log_parts = [ + f"router_status(window={elapsed:.1f}s)", + f"throughput(total={total_tps:.1f},input={input_tps:.1f},output={output_tps:.1f})", + f"req(running={running},waiting={queued},paused={paused_req_num})", + f"kv(used={kv_pct_no_cache:.1f}%)", + f"gpu_cache_hit(window={window_cache_hit_rate:.1f}%,global={global_cache_hit_rate:.1f}%)", + ] + + if self.global_mtp_accepted_total > 0: + decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total + avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1) + log_parts.append( + f"mtp(avg_tokens_per_step={avg_mtp_len:.2f}," + f"accepted={self.global_mtp_accepted_total},output={self.global_mtp_output_total})" + ) + + self.status_logger.info(" | ".join(log_parts)) + if logger.isEnabledFor(logging.DEBUG): + kv_unrefed_prefix_cache_pct = max(0.0, kv_pct - kv_pct_no_cache) + debug_parts = [ + "router_status_debug", + f"kv_physical={kv_pct:.1f}%", + f"kv_unrefed_prefix_cache={kv_unrefed_prefix_cache_pct:.1f}%", + f"throughput_tokens(input={self.prompt_tokens},output={self.output_tokens})", + f"gpu_cache_tokens(window={self.window_cache_total}/{self.window_input_total}," + f"global={self.global_cache_total}/{self.global_input_total})", + f"tracked_output_reqs={len(self._req_last_output_len)}", + self._get_gpu_status_for_debug(), + ] + logger.debug(" | ".join(debug_parts)) + + self.prompt_tokens = 0 + self.output_tokens = 0 + self.window_input_total = 0 + self.window_cache_total = 0 + self.last_print_time = now class RouterStatics: diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 1dffdaf681..ef86378533 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -176,7 +176,7 @@ async def loop_for_netio_req(self): while True: recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) if isinstance(recv_req, GroupReqIndexes): - logger.info( + logger.debug( f"visual recv req id {recv_req.group_req_id} " f"img count {len(recv_req.multimodal_params.images)}" ) diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index f15309d5cf..6df36c3aed 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -4,24 +4,40 @@ import logging import sys import os -import time from typing import Optional _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" -_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug") +_STATUS_FORMAT = "%(levelname)s [%(asctime)s] %(message)s" + +_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "info") _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0) _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None) +_RESET = "\033[0m" +_LEVEL_COLORS = { + logging.DEBUG: "\033[36m", # cyan + logging.INFO: "\033[32m", # green + logging.WARNING: "\033[33m", # yellow + logging.ERROR: "\033[31m", # red + logging.CRITICAL: "\033[1;31m", # bold red +} + class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" + """Adds logging prefix to newlines to align multi-line messages, with optional color on levelname.""" - def __init__(self, fmt, datefmt=None): + def __init__(self, fmt, datefmt=None, use_color=False): logging.Formatter.__init__(self, fmt, datefmt) + self.use_color = use_color def format(self, record): + if self.use_color: + color = _LEVEL_COLORS.get(record.levelno, "") + if color: + record = logging.makeLogRecord(record.__dict__) + record.levelname = color + record.levelname + _RESET msg = logging.Formatter.format(self, record) if record.message != "": parts = msg.split(record.message) @@ -39,7 +55,9 @@ def _setup_logger(): _root_logger.setLevel(_LOG_LEVEL) global _default_handler global _default_file_handler - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _use_color = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + color_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=_use_color) + plain_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False) if _default_handler is None: _default_handler = logging.StreamHandler(sys.stdout) @@ -55,10 +73,10 @@ def _setup_logger(): _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) - _default_file_handler.setFormatter(fmt) + _default_file_handler.setFormatter(plain_fmt) _root_logger.addHandler(_default_file_handler) - _default_handler.setFormatter(fmt) + _default_handler.setFormatter(color_fmt) # Setting this will avoid the message # being propagated to the parent logger. _root_logger.propagate = False @@ -89,29 +107,28 @@ def init_logger(name: str): _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") _inference_log_file_handler[pid] = logging.FileHandler(_LOG_DIR + f"/process.{pid}.log") _inference_log_file_handler[pid].setLevel(_LOG_LEVEL) - _inference_log_file_handler[pid].setFormatter(NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)) + _inference_log_file_handler[pid].setFormatter( + NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False) + ) _root_logger.addHandler(_inference_log_file_handler[pid]) logger.addHandler(_inference_log_file_handler[pid]) logger.propagate = False return logger -_log_time_mark_dict = {} - - -def log_time_ready(mark_name, time_count: int): - """ - time_count 间隔时间超过多少s调用该函数会返回True,否则返回False - 用于控制一些日志输出的频率 - """ - global _log_time_mark_dict - - if mark_name not in _log_time_mark_dict: - _log_time_mark_dict[mark_name] = time.time() - return False - cur_time_mark = time.time() - if cur_time_mark - _log_time_mark_dict[mark_name] >= time_count: - _log_time_mark_dict[mark_name] = cur_time_mark - return True - else: - return False +def init_system_status_logger(name: str): + logger = logging.getLogger(f"lightllm.status.{name}") + if not logger.handlers: + logger.setLevel(logging.INFO) + fmt = logging.Formatter(_STATUS_FORMAT, datefmt=_DATE_FORMAT) + handler = logging.StreamHandler(sys.stdout) + handler.flush = sys.stdout.flush + handler.setFormatter(fmt) + logger.addHandler(handler) + if _LOG_DIR is not None: + file_handler = logging.FileHandler(os.path.join(_LOG_DIR, f"status.{name}.log")) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(fmt) + logger.addHandler(file_handler) + logger.propagate = False + return logger diff --git a/unit_tests/server/config_server/test_access_log.py b/unit_tests/server/config_server/test_access_log.py new file mode 100644 index 0000000000..4978239806 --- /dev/null +++ b/unit_tests/server/config_server/test_access_log.py @@ -0,0 +1,14 @@ +from fastapi.testclient import TestClient + + +def test_config_server_emits_access_log(monkeypatch): + from lightllm.server.config_server import api_http + + messages = [] + monkeypatch.setattr(api_http.logger, "info", lambda msg, *args, **kwargs: messages.append(str(msg))) + + with TestClient(api_http.app) as client: + response = client.get("/health") + + assert response.status_code == 200 + assert any("GET /health 200" in message for message in messages) diff --git a/unit_tests/server/httpserver/test_prompt_length_guard.py b/unit_tests/server/httpserver/test_prompt_length_guard.py new file mode 100644 index 0000000000..f03a59a324 --- /dev/null +++ b/unit_tests/server/httpserver/test_prompt_length_guard.py @@ -0,0 +1,33 @@ +from types import SimpleNamespace + +import pytest + +from lightllm.server.httpserver.manager import HttpServerManager +from lightllm.server.httpserver_for_pd_master.manager import HttpServerManagerForPDMaster + + +class TokenizerMustNotRun: + vocab_size = 32000 + + def encode(self, *args, **kwargs): + raise AssertionError("tokenizer should not run for oversized text prompts") + + +def _fake_manager(): + return SimpleNamespace( + max_req_total_len=4, + tokenizer=TokenizerMustNotRun(), + args=SimpleNamespace(max_image_token_count=1024), + ) + + +def _empty_multimodal_params(): + return SimpleNamespace(images=[], audios=[]) + + +@pytest.mark.parametrize("manager_cls", [HttpServerManager, HttpServerManagerForPDMaster]) +def test_tokens_rejects_oversized_text_prompt_before_tokenization(manager_cls): + prompt = "x" * 33 + + with pytest.raises(ValueError, match="prompt text length 33 exceeds the character limit 32"): + manager_cls.tokens(_fake_manager(), prompt, _empty_multimodal_params(), SimpleNamespace()) diff --git a/unit_tests/server/test_qwen3_coder_stream_fc.py b/unit_tests/server/test_qwen3_coder_stream_fc.py new file mode 100644 index 0000000000..f72a642b11 --- /dev/null +++ b/unit_tests/server/test_qwen3_coder_stream_fc.py @@ -0,0 +1,194 @@ +"""Unit tests for Qwen3-Coder XML streaming tool-call parsing. + +These drive ``Qwen3CoderDetector.parse_streaming_increment`` directly (no server), +reassembling the per-tool argument string exactly the way ``api_openai.py`` does for +streamed responses, including the ``finish_reason == "stop"`` reconciliation. The key +invariant under test: the reassembled arguments are always valid JSON equal to what the +one-shot ``detect_and_parse`` would produce, for every chunk boundary. +""" + +import json +import pytest + +from lightllm.server.api_models import Function, Tool +from lightllm.server.function_call_parser import Qwen3CoderDetector + +CHUNK_SIZES = [1, 2, 3, 5, 13, 10_000] + + +def _tool(name, properties): + return Tool( + type="function", + function=Function(name=name, description="", parameters={"type": "object", "properties": properties}), + ) + + +def _stream_and_reassemble(text, tools, chunk): + """Feed ``text`` to the detector in fixed-size chunks and rebuild the client-visible + tool calls the way api_openai.py stream_results does (stop-rewrite on the last chunk).""" + det = Qwen3CoderDetector() + chunks = [text[i : i + chunk] for i in range(0, len(text), chunk)] + per_tool = {} + for ci, piece in enumerate(chunks): + result = det.parse_streaming_increment(piece, tools) + is_last = ci == len(chunks) - 1 + for call in result.calls: + ti = call.tool_index + per_tool.setdefault(ti, {"name": None, "args": ""}) + if call.name is not None: + per_tool[ti]["name"] = call.name + params = call.parameters + if is_last and params: + # Mirror api_openai.py:559-575 (REPLACE semantics). + latest_delta_len = len(params) + expected = json.dumps(det.prev_tool_call_arr[ti].get("arguments", {}), ensure_ascii=False) + actual = det.streamed_args_for_tool[ti] + if latest_delta_len > 0: + actual = actual[:-latest_delta_len] + params = expected.replace(actual, "", 1) + if params: + per_tool[ti]["args"] += params + return det, per_tool + + +def _assert_tool_calls(text, tools, expected, chunk_sizes=CHUNK_SIZES): + """expected: {tool_index: (name, args_dict)}.""" + for chunk in chunk_sizes: + det, per_tool = _stream_and_reassemble(text, tools, chunk) + assert len(per_tool) == len(expected), f"chunk={chunk}: tool count {len(per_tool)} != {len(expected)}" + for ti, (name, args) in expected.items(): + got = per_tool[ti] + assert got["name"] == name, f"chunk={chunk}: tool {ti} name {got['name']!r} != {name!r}" + parsed = json.loads(got["args"]) # must be valid JSON + assert parsed == args, f"chunk={chunk}: tool {ti} args {parsed!r} != {args!r}" + + +@pytest.mark.parametrize("chunk", CHUNK_SIZES) +def test_single_string_param(chunk): + text = ( + "\n\n\n" + "San Francisco\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("get_weather", {"location": {"type": "string"}})], + {0: ("get_weather", {"location": "San Francisco"})}, + [chunk], + ) + + +def test_array_param_compact_spacing(): + # Regression: a non-string value whose raw text ("[1,2]") differs from json.dumps + # ("[1, 2]") used to break the streamed-args prefix invariant -> duplicated/invalid JSON. + text = "\n\n\n[1,2]\n\n\n" + _assert_tool_calls(text, [_tool("calc", {"nums": {"type": "array"}})], {0: ("calc", {"nums": [1, 2]})}) + + +def test_number_param_reformatted(): + # "1.0" is parsed to int 1 and json.dumps'd as "1"; the stream must agree. + text = "\n\n\n1.0\n\n\n" + _assert_tool_calls(text, [_tool("calc", {"v": {"type": "number"}})], {0: ("calc", {"v": 1})}) + + +def test_boolean_param(): + text = "\n\n\ntrue\n\n\n" + _assert_tool_calls(text, [_tool("set", {"flag": {"type": "boolean"}})], {0: ("set", {"flag": True})}) + + +def test_object_param(): + text = '\n\n\n{"a":1,"b":[2,3]}\n\n\n' + _assert_tool_calls(text, [_tool("f", {"cfg": {"type": "object"}})], {0: ("f", {"cfg": {"a": 1, "b": [2, 3]}})}) + + +def test_two_params_mixed_types(): + text = ( + "\n\n\nNYC\n\n" + "\n3\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("f", {"city": {"type": "string"}, "days": {"type": "integer"}})], + {0: ("f", {"city": "NYC", "days": 3})}, + ) + + +def test_multiline_string_value(): + text = ( + "\n\n\nline1\nline2\nline3\n" "\n\n" + ) + _assert_tool_calls(text, [_tool("f", {"code": {"type": "string"}})], {0: ("f", {"code": "line1\nline2\nline3"})}) + + +def test_string_with_json_special_chars(): + text = '\n\n\nsay "hi"\\path\n\n\n' + _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": 'say "hi"\\path'})}) + + +def test_empty_string_value(): + text = "\n\n\n\n\n\n" + _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": ""})}) + + +def test_no_param_function(): + text = "\n\n\n" + _assert_tool_calls(text, [_tool("ping", {})], {0: ("ping", {})}) + + +def test_two_separate_tool_call_blocks(): + text = ( + "\n\n\nhi\n\n\n\n" + "\n\n\nyo\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})], + {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})}, + ) + + +def test_two_functions_in_one_block(): + # Regression: used to raise IndexError on the second function in a single block. + text = ( + "\n\n\nhi\n\n\n" + "\n\nyo\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})], + {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})}, + ) + + +def test_undefined_then_valid_in_same_block(): + # Regression: an undefined first function used to discard the whole block, dropping + # the valid call that followed it. + text = ( + "\n\n\nhi\n\n\n" + "\n\nyo\n\n\n" + ) + _assert_tool_calls(text, [_tool("valid", {"y": {"type": "string"}})], {0: ("valid", {"y": "yo"})}) + + +def test_truncated_call_missing_function_close(): + # Regression: a typed value with no before used to leave the + # streamed args unterminated (missing closing brace). + text = "\n\n\n0.50\n\n" + _assert_tool_calls(text, [_tool("calc", {"x": {"type": "number"}})], {0: ("calc", {"x": 0.5})}) + + +def test_streaming_matches_non_stream(): + # The reassembled streamed args must equal the one-shot detect_and_parse output. + tools = [_tool("f", {"city": {"type": "string"}, "n": {"type": "integer"}, "tags": {"type": "array"}})] + text = ( + "\n\n\nLondon\n\n" + '\n7\n\n\n["a","b"]\n\n\n' + ) + oneshot = Qwen3CoderDetector().detect_and_parse(text, tools) + expected_args = json.loads(oneshot.calls[0].parameters) + _assert_tool_calls(text, tools, {0: ("f", expected_args)}) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"]))