From 739e4910069bc2f98720bedd80ca660644c9b179 Mon Sep 17 00:00:00 2001 From: luukunn <981429396@qq.com> Date: Mon, 30 Mar 2026 16:41:21 +0800 Subject: [PATCH 1/3] fix tool parser --- .../ernie_45_vl_thinking_tool_parser.py | 4 - .../tool_parsers/ernie_x1_tool_parser.py | 468 +++++++++--------- 2 files changed, 223 insertions(+), 249 deletions(-) diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py index 79eb3058b15..4e6cc93cbaa 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py @@ -246,10 +246,6 @@ def extract_tool_calls_streaming( if self.valid is not None and not self.valid: return DeltaMessage(content=delta_text) - # Skip empty chunks - if len(delta_text.strip()) == 0: - return None - try: delta = None # Use buffer to accumulate delta_text content diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index 3427a1814fa..60371c3616f 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -18,7 +18,6 @@ import re import uuid from collections.abc import Sequence -from typing import Union import partial_json_parser @@ -28,6 +27,8 @@ def random_tool_call_id() -> str: return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" +from partial_json_parser.core.options import Allow + from fastdeploy.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaFunctionCall, @@ -41,40 +42,44 @@ def random_tool_call_id() -> str: ToolParser, ToolParserManager, ) -from fastdeploy.utils import data_processor_logger +from fastdeploy.utils import data_processor_logger as logger @ToolParserManager.register_module("ernie-x1") class ErnieX1ToolParser(ToolParser): """ - Tool parser for Ernie model version 4.5.1. This parser handles tool calls with newline formats. """ def __init__(self, tokenizer): + """ + Ernie thinking model format: + abc\n\n\n\n\ndef\n\n + """ super().__init__(tokenizer) - + self.current_tool_name_sent = False self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list - self.buffer: str = "" # buffer for accumulating unprocessed streaming content - self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas - self.tool_call_start_token: str = "" - self.tool_call_end_token: str = "" + self.current_tool_id = -1 + self.streamed_args_for_tool: list[str] = [] + self.think_end_token = "" + self.response_start_token: str = "" + self.response_end_token: str = "" + self.tool_call_start_token = "" + self.tool_call_end_token = "" - self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) - self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: - raise RuntimeError( - "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!" - ) + self.tool_call_regex = re.compile(r"\s*(?P\{.*?\})\s*", re.DOTALL) if not self.model_tokenizer: raise ValueError( - "The model tokenizer must be passed to the ToolCallParser constructor during construction." + "The model tokenizer must be passed to the ToolParser " "constructor during construction." ) + self.think_end_token_id = self.vocab.get(self.think_end_token) + self.response_start_token_id = self.vocab.get(self.response_start_token) + self.response_end_token_id = self.vocab.get(self.response_end_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. @@ -88,144 +93,28 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) """ try: + tool_call_json_list = self.tool_call_regex.findall(model_output) tool_calls = [] - - # Check for invalid tags before tool calls - if re.search(r"[\s\S]*?\s*(?=)", model_output): - data_processor_logger.error("Invalid format: tags found before ") - return ExtractedToolCallInformation(tools_called=False, content=model_output) - - function_call_arr = [] - remaining_text = model_output - - while True: - # Find the next - tool_call_pos = remaining_text.find("") - if tool_call_pos == -1: - break - - # Extract content after - tool_content_start = tool_call_pos + len("") - tool_content_end = remaining_text.find("", tool_content_start) - - tool_json = "" - if tool_content_end == -1: - # Processing unclosed tool_call block (truncated case) - tool_json = remaining_text[tool_content_start:].strip() - remaining_text = "" # No more content to process - else: - # Processing closed block - tool_json = remaining_text[tool_content_start:tool_content_end].strip() - remaining_text = remaining_text[tool_content_end + len("") :] - - if not tool_json: - continue - - # Process tool_json - tool_json = tool_json.strip() - if not tool_json.startswith("{"): - tool_json = "{" + tool_json - if not tool_json.endswith("}"): - tool_json = tool_json + "}" - - try: - # Parsing strategy: First try standard json.loads - try: - tool_data = json.loads(tool_json) - - if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data: - function_call_arr.append( - { - "name": tool_data["name"], - "arguments": tool_data["arguments"], - "_is_complete": True, # Mark as complete - } - ) - continue - except json.JSONDecodeError: - pass - - # Try partial_json_parser when standard parsing fails - from partial_json_parser.core.options import Allow - - try: - tool_data = {} - flags = Allow.ALL & ~Allow.STR - - # Parse the name field - name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json) - if name_match: - tool_data["name"] = name_match.group(1) - - # Parse the arguments field - args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json) - if args_match: - try: - tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags) - except Exception as e: - data_processor_logger.debug(f"Failed to parse tool arguments: {e}") - tool_data["arguments"] = None - - if isinstance(tool_data, dict): - function_call_arr.append( - { - "name": tool_data.get("name", ""), - "arguments": tool_data.get("arguments", {}), - "_is_partial": True, # Mark as partial - } - ) - except Exception as e: - data_processor_logger.debug(f"Failed to parse tool call: {str(e)}") - continue - except Exception as e: - data_processor_logger.debug(f"Failed to parse tool call: {str(e)}") - continue - - if not function_call_arr: - data_processor_logger.error("No valid tool calls found") - return ExtractedToolCallInformation(tools_called=False, content=model_output) - - tool_calls = [] - all_complete = True # Initialize as all complete - - for tool_call in function_call_arr: - # Set flags - is_complete = tool_call.get("_is_complete", False) - is_partial = tool_call.get("_is_partial", False) - - # If any tool call is incomplete or partial, mark all_complete as False - if not is_complete or is_partial: - all_complete = False - - # Process arguments - tool_args = tool_call.get("arguments", {}) - if not isinstance(tool_args, dict): - tool_args = {} - - try: - args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}" - except: - args_str = "{}" - + for tool_call_json in tool_call_json_list: + tool_call_dict = json.loads(tool_call_json) + args_str = json.dumps(tool_call_dict.get("arguments", {}), ensure_ascii=False) tool_calls.append( ToolCall( type="function", id=random_tool_call_id(), function=FunctionCall( - name=tool_call.get("name", ""), + name=tool_call_dict.get("name", ""), arguments=args_str, ), ) ) - - # Only return tools_called=True if all tool calls are complete return ExtractedToolCallInformation( - tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content="" + tools_called=True, + tool_calls=tool_calls, ) - - except Exception as e: - data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}") - return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output) + except Exception: + logger.warning("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) def extract_tool_calls_streaming( self, @@ -235,114 +124,203 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - request: dict, - ) -> Union[DeltaMessage, None]: - - if self.tool_call_start_token_id not in current_token_ids: + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + if self.tool_call_start_token not in current_text: + logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - # Skip empty chunks - if len(delta_text.strip()) == 0: - return None try: - delta = None - # Use buffer to accumulate delta_text content - self.buffer += delta_text - - # Process the buffer content - if "" in delta_text: - self.current_tool_id = ( - max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1 + prev_tool_start_count = previous_text.count(self.tool_call_start_token) + prev_tool_end_count = previous_text.count(self.tool_call_end_token) + cur_tool_start_count = current_text.count(self.tool_call_start_token) + cur_tool_end_count = current_text.count(self.tool_call_end_token) + tool_call_portion = None + text_portion = None + + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0].rstrip() ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() + + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + + if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + self.current_tool_id += 1 self.current_tool_name_sent = False - if len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") - data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}") - - # 1. Try to parse the name field - if not self.current_tool_name_sent and '"name"' in self.buffer: - name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer) - if name_match: - name = name_match.group(1) - if name: - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=random_tool_call_id(), - function=DeltaFunctionCall(name=name).model_dump(exclude_none=True), - ) - ] - ) - # Delete the processed name part from the buffer - self.buffer = self.buffer[name_match.end() :] - self.current_tool_name_sent = True - return delta - # 2. Processing arguments field - if '"arguments"' in self.buffer: - args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer) - if args_match: - args_content = args_match.group(1) - try: - # Check if arguments field is complete by bracket matching - if "}}" in args_content: - matched_pos = -1 - for i, ch in enumerate(delta_text): - if ch == "{": - self.bracket_counts["total_l"] += 1 - elif ch == "}": - self.bracket_counts["total_r"] += 1 - - if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: - matched_pos = i - break - - if matched_pos >= 0: - # Clean up bracket counts for next tool call - truncate_text = delta_text[: matched_pos + 1] - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall(arguments=truncate_text).model_dump( - exclude_none=True - ), - ) - ] - ) - self.buffer = self.buffer[args_match.end() :] - return delta - else: - # No complete match yet - return None - else: - # Return partial arguments - for ch in delta_text: - if ch == "{": - self.bracket_counts["total_l"] += 1 - elif ch == "}": - self.bracket_counts["total_r"] += 1 - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True), - ) - ] - ) - return delta - except Exception as e: - data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}") + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] + text_portion = None + + elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count >= prev_tool_end_count: + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") + if diff: + diff = diff.encode("utf-8").decode("unicode_escape") if diff is str else diff + if '"}' not in delta_text: return None - if "" in self.buffer: - end_pos = self.buffer.find("") - self.buffer = self.buffer[end_pos + len("") :] + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True), + ) + ] + ) - self.streamed_args_for_tool.append("") + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + current_tool_call = ( + partial_json_parser.loads(tool_call_portion or "{}", flags) if tool_call_portion else None + ) + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug("not enough tokens to parse into JSON yet") + return None + except json.decoder.JSONDecodeError: + logger.debug("unable to parse JSON") + return None + + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: str | None = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True), + ) + ] + ) + else: + return None + + if tool_call_portion is None: + delta = DeltaMessage(content=delta_text) if text_portion is not None else None + return delta + + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.") + delta = None + + elif cur_arguments and not prev_arguments: + function_name = current_tool_call.get("name") + match = re.search( + r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', + tool_call_portion.strip(), + re.DOTALL, + ) + if match: + cur_arguments_json = match.group(1) + else: + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) + + logger.debug("finding %s in %s", delta_text, cur_arguments_json) + + if delta_text not in cur_arguments_json: + return None + args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(delta_text) + + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", arguments_delta) + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=arguments_delta).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + try: + json.loads(tool_call_portion) + is_complete_json = True + except Exception: + is_complete_json = False + + if ( + isinstance(delta_text, str) + and len(delta_text.rstrip()) >= 1 + and delta_text.rstrip()[-1] == "}" + and is_complete_json + ): + delta_text = delta_text.rstrip()[:-1] + + logger.debug("got diff %s", delta_text) + if is_complete_json and delta_text.strip() == "": + return None + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += delta_text + + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) return delta - except Exception as e: - data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}") - return None + except Exception: + logger.warning("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. From cb3a55ff9c62f31eeb3d861bfd6af471f986380e Mon Sep 17 00:00:00 2001 From: luukunn <981429396@qq.com> Date: Mon, 30 Mar 2026 17:03:02 +0800 Subject: [PATCH 2/3] fix unit test --- .../tool_parsers/test_ernie_x1_tool_parser.py | 137 ------------------ 1 file changed, 137 deletions(-) delete mode 100644 tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py deleted file mode 100644 index c8a24c5707d..00000000000 --- a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import unittest - -from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage -from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import ( - ErnieX1ToolParser, -) - - -class DummyTokenizer: - """Dummy tokenizer with vocab containing tool_call tokens""" - - def __init__(self): - self.vocab = {"": 1, "": 2} - - def get_vocab(self): - return self.vocab - - -class TestErnieX1ToolParser(unittest.TestCase): - def setUp(self): - self.tokenizer = DummyTokenizer() - self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer) - self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}]) - - # ---------------- Batch extraction tests ---------------- - - def test_extract_tool_calls_complete(self): - """Test normal extraction of complete tool_call JSON""" - output = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' - result = self.parser.extract_tool_calls(output, self.dummy_request) - self.assertTrue(result.tools_called) - self.assertEqual(result.tool_calls[0].function.name, "get_weather") - - def test_extract_tool_calls_no_toolcall(self): - """Test when no tool_call tags are present""" - output = "no tool call here" - result = self.parser.extract_tool_calls(output, self.dummy_request) - self.assertFalse(result.tools_called) - - def test_extract_tool_calls_exception(self): - """Completely broken JSON triggers the exception branch""" - output = "not json at all{{{" - result = self.parser.extract_tool_calls(output, self.dummy_request) - self.assertFalse(result.tools_called) - - def test_extract_tool_calls_partial_json_parser_failure(self): - """Test partial_json_parser failure path for arguments (L165-166). - json.loads fails on malformed JSON, partial_json_parser.loads also fails on deeply broken args. - Partial result has _is_partial=True so tools_called=False, but tool_calls is populated.""" - # Malformed JSON: valid name but arguments is a bare invalid token - # that breaks both json.loads and partial_json_parser - output = '{"name": "test", "arguments": @@@INVALID@@@}' - result = self.parser.extract_tool_calls(output, self.dummy_request) - # _is_partial=True → tools_called=False, but tool_calls list is populated - self.assertFalse(result.tools_called) - self.assertIsNotNone(result.tool_calls) - self.assertEqual(result.tool_calls[0].function.name, "test") - # arguments=None → converted to {} → serialized as "{}" - self.assertEqual(result.tool_calls[0].function.arguments, "{}") - - def test_partial_json_parser_exception_triggers_debug_log(self): - """Malformed JSON + partial_json_parser failure exercises L165-166 exactly.""" - # Unclosed string in arguments breaks both json.loads and partial_json_parser - output = '{"name": "my_tool", "arguments": {"key": "unterminated}' - result = self.parser.extract_tool_calls(output, self.dummy_request) - # Partial parse → tools_called=False but tool_calls has entries - self.assertFalse(result.tools_called) - self.assertIsNotNone(result.tool_calls) - self.assertEqual(result.tool_calls[0].function.name, "my_tool") - - # ---------------- Streaming extraction tests ---------------- - - def test_streaming_no_toolcall(self): - """Streaming extraction returns normal DeltaMessage when no toolcall tag""" - result = self.parser.extract_tool_calls_streaming( - "", "abc", "abc", [], [], [], self.dummy_request.model_dump() - ) - self.assertIsInstance(result, DeltaMessage) - self.assertEqual(result.content, "abc") - - def test_streaming_skip_empty_chunk(self): - """Streaming extraction skips empty chunks""" - result = self.parser.extract_tool_calls_streaming( - "", "", " ", [], [1], [1], self.dummy_request.model_dump() - ) - self.assertIsNone(result) - - def test_streaming_new_toolcall_and_name(self): - """Streaming extraction detects new toolcall and extracts name""" - delta = self.parser.extract_tool_calls_streaming( - "", "", '{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump() - ) - self.assertIsNotNone(delta) - self.assertEqual(delta.tool_calls[0].function.name, "get_weather") - - def test_streaming_partial_arguments(self): - """Streaming extraction yields partial arguments deltas""" - text = '"arguments": {"location":' - delta = self.parser.extract_tool_calls_streaming( - "", "" + text, text, [], [1], [1], self.dummy_request.model_dump() - ) - self.assertIsInstance(delta, DeltaMessage) - self.assertIn("arguments", delta.tool_calls[0].function.arguments) - - def test_streaming_complete_arguments_and_end(self): - """Streaming extraction completes arguments with brackets matched and closes tool_call""" - text = '"arguments": {"location": "Beijing"}}' - delta = self.parser.extract_tool_calls_streaming( - "", "" + text, text, [], [1], [1], self.dummy_request.model_dump() - ) - self.assertIsInstance(delta, DeltaMessage) - # Also simulate closing tag - end_delta = self.parser.extract_tool_calls_streaming( - "", "", "", [], [2], [2], self.dummy_request.model_dump() - ) - self.assertIsNotNone(end_delta) - self.assertEqual(end_delta.content, "") - - -if __name__ == "__main__": - unittest.main() From 818046ebe3d7a58289fc6834182a786589d26625 Mon Sep 17 00:00:00 2001 From: luukunn <981429396@qq.com> Date: Mon, 30 Mar 2026 20:05:29 +0800 Subject: [PATCH 3/3] add unit test --- .../tool_parsers/ernie_x1_tool_parser.py | 1 - .../tool_parsers/test_ernie_x1_tool_parser.py | 854 ++++++++++++++++++ 2 files changed, 854 insertions(+), 1 deletion(-) create mode 100644 tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index 60371c3616f..ba162b2d516 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -181,7 +181,6 @@ def extract_tool_calls_streaming( return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = diff.encode("utf-8").decode("unicode_escape") if diff is str else diff if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py new file mode 100644 index 00000000000..5cde01c3933 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py @@ -0,0 +1,854 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import patch + +from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import ( + ErnieX1ToolParser, +) + + +class TestErnieX1ToolParser(unittest.TestCase): + def setUp(self): + class DummyTokenizer: + def __init__(self): + self.vocab = { + "": 1, + "": 2, + "": 3, + "": 4, + "": 5, + } + + def get_vocab(self): + return self.vocab + + self.tokenizer = DummyTokenizer() + self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer) + self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}]) + + def _new_parser(self): + """Create a fresh parser to avoid state pollution between tests.""" + + class DummyTokenizer: + def __init__(self): + self.vocab = { + "": 1, + "": 2, + "": 3, + "": 4, + "": 5, + } + + def get_vocab(self): + return self.vocab + + return ErnieX1ToolParser(tokenizer=DummyTokenizer()) + + # ==================== __init__ tests (lines 60-81) ==================== + + def test_init_sets_tokens_and_ids(self): + """Cover lines 60-81: verify all token attributes and vocab lookups""" + p = self.parser + self.assertFalse(p.current_tool_name_sent) + self.assertEqual(p.prev_tool_call_arr, []) + self.assertEqual(p.current_tool_id, -1) + self.assertEqual(p.streamed_args_for_tool, []) + self.assertEqual(p.think_end_token, "") + self.assertEqual(p.response_start_token, "") + self.assertEqual(p.response_end_token, "") + self.assertEqual(p.tool_call_start_token, "") + self.assertEqual(p.tool_call_end_token, "") + self.assertIsNotNone(p.tool_call_regex) + self.assertEqual(p.think_end_token_id, 3) + self.assertEqual(p.response_start_token_id, 4) + self.assertEqual(p.response_end_token_id, 5) + self.assertEqual(p.tool_call_start_token_id, 1) + self.assertEqual(p.tool_call_end_token_id, 2) + + def test_init_raises_without_tokenizer(self): + """Cover lines 72-75: ValueError when tokenizer is falsy""" + with self.assertRaises(ValueError): + ErnieX1ToolParser(tokenizer=None) + + # ==================== extract_tool_calls tests (lines 96-117) ==================== + + def test_extract_tool_calls_single(self): + """Cover lines 96-114: single complete tool call""" + output = '{"name": "get_weather", "arguments": {"location": "北京"}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(len(result.tool_calls), 1) + self.assertEqual(result.tool_calls[0].function.name, "get_weather") + self.assertIn("北京", result.tool_calls[0].function.arguments) + + def test_extract_tool_calls_multiple(self): + """Cover lines 98-100: multiple tool calls""" + output = ( + '{"name": "get_weather", "arguments": {"location": "北京"}}' + '{"name": "get_time", "arguments": {"timezone": "UTC"}}' + ) + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(len(result.tool_calls), 2) + self.assertEqual(result.tool_calls[0].function.name, "get_weather") + self.assertEqual(result.tool_calls[1].function.name, "get_time") + + def test_extract_tool_calls_no_arguments(self): + """Cover line 100: tool call with no arguments defaults to {}""" + output = '{"name": "list_items"}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(result.tool_calls[0].function.arguments, "{}") + + def test_extract_tool_calls_nested_arguments(self): + """Cover regex with nested braces in arguments""" + output = '{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertIn("$gt", result.tool_calls[0].function.arguments) + + def test_extract_tool_calls_with_whitespace(self): + """Cover regex with whitespace around JSON""" + output = ' \n{"name": "fn", "arguments": {}} \n' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "fn") + + def test_extract_tool_calls_no_match(self): + """Cover lines 96, 111-114: no tool_call tags -> tools_called=True with empty list""" + output = "just plain text" + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(len(result.tool_calls), 0) + + def test_extract_tool_calls_invalid_json(self): + """Cover lines 115-117: malformed JSON triggers exception branch""" + output = "{invalid json}" + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + self.assertEqual(result.content, output) + + def test_extract_tool_calls_exception(self): + """Cover lines 115-117: forced exception via mock""" + with patch( + "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", + side_effect=Exception("boom"), + ): + output = '{"name": "get_weather", "arguments": {}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + self.assertEqual(result.content, output) + + # ==================== extract_tool_calls_streaming tests ==================== + + # --- Line 129-131: no tool_call_start_token in current_text --- + + def test_streaming_no_tool_call_token(self): + """Cover lines 129-131: no in current_text returns content delta""" + result = self.parser.extract_tool_calls_streaming("", "hello world", "world", [], [], [], self.dummy_request) + self.assertIsInstance(result, DeltaMessage) + self.assertEqual(result.content, "world") + self.assertIsNone(result.tool_calls) + + # --- Lines 141-147: balanced start/end counts, text generation after tool call --- + + def test_streaming_balanced_counts_text_after_tool(self): + """Cover lines 134-147: start==end, prev_end==cur_end, end not in delta -> text content""" + prev = "{}" + cur = "{} some text" + delta = " some text" + result = self.parser.extract_tool_calls_streaming(prev, cur, delta, [1, 2], [1, 2], [], self.dummy_request) + self.assertIsInstance(result, DeltaMessage) + self.assertEqual(result.content, delta) + + # --- Lines 149-156: tool_call_end_token in delta_text --- + + def test_streaming_end_token_in_delta(self): + """Cover lines 149-156: appears in delta""" + parser = self._new_parser() + # First, start a tool call + parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # Now stream arguments + parser.extract_tool_calls_streaming( + '{"name": "fn"', + '{"name": "fn", "arguments": {"k": "v', + ', "arguments": {"k": "v', + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + # Close with end token in delta + result = parser.extract_tool_calls_streaming( + '{"name": "fn", "arguments": {"k": "v', + '{"name": "fn", "arguments": {"k": "v"}}', + '"}}', + [1, 10, 20], + [1, 10, 20, 2], + [2], + self.dummy_request, + ) + # Should handle end token + self.assertTrue(result is None or isinstance(result, DeltaMessage)) + + # --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) --- + + def test_streaming_new_tool_call_single_token(self): + """Cover lines 160-172 (len(delta_token_ids)==1): new tool start with single token""" + parser = self._new_parser() + result = parser.extract_tool_calls_streaming( + "", + "", + "", + [], + [1], + [1], + self.dummy_request, + ) + # tool_call_portion is None, current_tool_call is None, name not sent -> None + self.assertIsNone(result) + self.assertEqual(parser.current_tool_id, 0) + self.assertEqual(len(parser.streamed_args_for_tool), 1) + + def test_streaming_new_tool_call_multi_tokens(self): + """Cover lines 160-162 (len(delta_token_ids)>1): new tool start with content""" + parser = self._new_parser() + result = parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + self.assertIsNotNone(result) + self.assertEqual(result.tool_calls[0].function.name, "fn") + self.assertEqual(parser.current_tool_id, 0) + + # --- Lines 174-176: continuing inside existing tool (cur_start > cur_end, same start count) --- + + def test_streaming_continue_tool_call_no_name_yet(self): + """Cover lines 174-176, 220-222: partial JSON without name yet""" + parser = self._new_parser() + # Start tool call + parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Continue with partial content, no name parseable yet + result = parser.extract_tool_calls_streaming( + "", + '{"na', + '{"na', + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertIsNone(result) + + def test_streaming_continue_tool_call_with_name(self): + """Cover lines 174-176, 223-235: name becomes available""" + parser = self._new_parser() + # Start tool call + parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Name appears + result = parser.extract_tool_calls_streaming( + "", + '{"name": "get_weather"', + '{"name": "get_weather"', + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertIsNotNone(result) + self.assertEqual(result.tool_calls[0].function.name, "get_weather") + self.assertTrue(parser.current_tool_name_sent) + + # --- Lines 236-237: name not sent and function_name is None --- + + def test_streaming_no_function_name(self): + """Cover lines 236-237: parsed JSON has no 'name' field""" + parser = self._new_parser() + parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Send JSON without name field + result = parser.extract_tool_calls_streaming( + "", + '{"arguments": {"k": "v"}}', + '{"arguments": {"k": "v"}}', + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertIsNone(result) + + # --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) --- + + def test_streaming_close_no_prev_tool_call(self): + """Cover lines 178-181: close branch with empty prev_tool_call_arr""" + parser = self._new_parser() + parser.prev_tool_call_arr = [] + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + result = parser.extract_tool_calls_streaming( + '{"name":"fn","arguments":{"k":"v"}}', + '{"name":"fn","arguments":{"k":"v"}}', + "", + [1, 10], + [1, 10, 2], + [2], + self.dummy_request, + ) + self.assertIsNone(result) + + def test_streaming_close_with_remaining_diff(self): + """Cover lines 182-200: close with arguments diff that hasn't been streamed""" + parser = self._new_parser() + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] + result = parser.extract_tool_calls_streaming( + '{"name":"fn","arguments":{"k":"v"}}', + '{"name":"fn","arguments":{"k":"v"}}', + '"}}', + [1, 10], + [1, 10, 2], + [2], + self.dummy_request, + ) + self.assertIsNotNone(result) + self.assertIsNotNone(result.tool_calls) + + def test_streaming_close_with_diff_no_end_marker(self): + """Cover lines 184-185: close with arguments but no '"}' in delta_text""" + parser = self._new_parser() + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] + # Simulate end token in delta but without '"}' pattern + # We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta + # so that we enter the elif at 178 + result = parser.extract_tool_calls_streaming( + '{"name":"fn","arguments":{"k":"v"}}', + '{"name":"fn","arguments":{"k":"v"}} text', + " text", + [1, 10, 2], + [1, 10, 2, 30], + [30], + self.dummy_request, + ) + # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147) + self.assertIsInstance(result, DeltaMessage) + + def test_streaming_close_no_arguments(self): + """Cover lines 182-183: close branch where prev arguments is None/empty""" + parser = self._new_parser() + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = [{"name": "fn"}] # no arguments key + result = parser.extract_tool_calls_streaming( + '{"name":"fn"}', + '{"name":"fn"}', + "}", + [1, 10], + [1, 10, 2], + [2], + self.dummy_request, + ) + # diff is None (no arguments), so falls through to partial_json_parser + self.assertTrue(result is None or isinstance(result, DeltaMessage)) + + # --- Lines 202-206: else branch (cur_start < cur_end, edge case) --- + + def test_streaming_else_branch(self): + """Cover lines 202-206: fall-through else branch""" + parser = self._new_parser() + parser.current_tool_name_sent = True + # Construct a scenario where cur_start < cur_end (more end tags than start) + prev = "" + cur = "" + delta = "" + result = parser.extract_tool_calls_streaming(prev, cur, delta, [1], [1, 2, 2], [2], self.dummy_request) + self.assertIsInstance(result, DeltaMessage) + self.assertEqual(result.tool_calls, []) + + # --- Lines 208-218: partial_json_parser errors --- + + def test_streaming_malformed_json(self): + """Cover lines 213-215: MalformedJSON from partial parser""" + parser = self._new_parser() + parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Feed badly formed content + result = parser.extract_tool_calls_streaming( + "", + "{{{", + "{{{", + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertIsNone(result) + + def test_streaming_json_decode_error(self): + """Cover lines 216-218: JSONDecodeError from partial parser""" + parser = self._new_parser() + parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + with patch( + "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", + side_effect=ValueError("bad json"), + ): + result = parser.extract_tool_calls_streaming( + "", + "bad", + "bad", + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertIsNone(result) + + # --- Lines 239-241: tool_call_portion is None after name sent --- + + def test_streaming_tool_portion_none_with_text(self): + """Cover lines 239-241: tool_call_portion is None, text_portion is not None""" + parser = self._new_parser() + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = [{}] + # Force tool_call_portion = None and text_portion = not None + # This happens when end_token is in delta (sets text_portion) but new tool start + # overrides tool_call_portion to None with single token + # Simulate: new tool start with single token AND end token in delta + # Actually, the simplest path: end token in delta sets text_portion, then new tool start + # sets tool_call_portion = None + # Let's use a different approach - directly test via the continuing branch + # where tool_call_portion remains None from the end_token path + result = parser.extract_tool_calls_streaming( + '{"name":"fn"}', + '{"name":"fn"}', + "", + [1, 10], + [1, 10, 2, 1], + [2, 1], + self.dummy_request, + ) + self.assertTrue(result is None or isinstance(result, DeltaMessage)) + + # --- Lines 243-244: append to prev_tool_call_arr --- + + def test_streaming_first_arguments_with_regex_match(self): + """Cover lines 243-244, 257-286: first arguments appear, regex matches""" + parser = self._new_parser() + # Start tool call and send name + parser.extract_tool_calls_streaming( + "", + '{"name": "get_weather"', + '{"name": "get_weather"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # Now stream arguments (first time) + # Key must be complete (closing quote) so partial_json_parser returns truthy arguments. + # delta must be a substring of the regex-extracted arguments portion (after "arguments":). + result = parser.extract_tool_calls_streaming( + '{"name": "get_weather"', + '{"name": "get_weather", "arguments": {"location": "bei', + '"bei', + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + self.assertIsNotNone(result) + self.assertIsNotNone(result.tool_calls) + + def test_streaming_first_arguments_no_regex_match(self): + """Cover lines 266-267: regex doesn't match, fallback to json.dumps""" + parser = self._new_parser() + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = [{}] + # Use tool_call_portion where key order differs so regex won't match + # (regex expects {"name":... at the start, but here "extra" comes first) + with patch( + "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", + return_value={"name": "fn", "extra": True, "arguments": {"k": "v"}}, + ): + result = parser.extract_tool_calls_streaming( + "", + '{"extra": true, "name": "fn", "arguments": {"k": "v"}}', + '"v"}', + [1], + [1, 10], + [10], + self.dummy_request, + ) + # regex fails on {"extra":... format, falls back to json.dumps + # delta '"v"}' is in json.dumps({"k": "v"}) = '{"k": "v"}' + self.assertIsNotNone(result) + self.assertIsNotNone(result.tool_calls) + + def test_streaming_first_arguments_delta_not_in_json(self): + """Cover lines 271-272: delta_text not found in cur_arguments_json""" + parser = self._new_parser() + parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # Delta text that doesn't appear in the arguments JSON + result = parser.extract_tool_calls_streaming( + '{"name": "fn"', + '{"name": "fn", "arguments": {"k": "v"}}', + "ZZZZZ", + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + self.assertIsNone(result) + + # --- Lines 249-251: no cur_arguments and no prev_arguments --- + + def test_streaming_no_arguments_at_all(self): + """Cover lines 249-251: both cur and prev arguments are empty/None""" + parser = self._new_parser() + parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # Continue with name only, no arguments + result = parser.extract_tool_calls_streaming( + '{"name": "fn"', + '{"name": "fn"}', + "}", + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + # prev_arguments=None, cur_arguments=None -> delta=None + # then prev_tool_call_arr updated and returns delta (which is None) + self.assertIsNone(result) + + # --- Lines 253-255: cur_arguments reset (impossible branch) --- + + def test_streaming_arguments_reset_mid_call(self): + """Cover lines 253-255: prev has arguments but cur doesn't (impossible case)""" + parser = self._new_parser() + parser.current_tool_id = 0 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] + # Feed content where cur has no arguments but prev does + with patch( + "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", + return_value={"name": "fn"}, + ): + result = parser.extract_tool_calls_streaming( + '{"name": "fn", "arguments": {"k": "v"', + '{"name": "fn", "arguments": {"k": "v"}', + '"}', + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + self.assertIsNone(result) + + # --- Lines 288-314: cur_arguments and prev_arguments both present --- + + def test_streaming_incremental_arguments_incomplete(self): + """Cover lines 288-314: both prev and cur have arguments, JSON incomplete""" + parser = self._new_parser() + parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # First arguments - delta must appear in regex-extracted arguments portion + parser.extract_tool_calls_streaming( + '{"name": "fn"', + '{"name": "fn", "arguments": {"k": "v', + '{"k": "v', + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + # More argument tokens (both prev and cur have arguments now) + result = parser.extract_tool_calls_streaming( + '{"name": "fn", "arguments": {"k": "v', + '{"name": "fn", "arguments": {"k": "val', + "al", + [1, 10, 20], + [1, 10, 20, 30], + [30], + self.dummy_request, + ) + self.assertIsNotNone(result) + self.assertEqual(result.tool_calls[0].function.arguments, "al") + + def test_streaming_incremental_arguments_complete_json(self): + """Cover lines 289-305: complete JSON with trailing }""" + parser = self._new_parser() + parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # First arguments - delta must appear in regex-extracted arguments portion + parser.extract_tool_calls_streaming( + '{"name": "fn"', + '{"name": "fn", "arguments": {"k": "v', + '{"k": "v', + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + # Complete with closing braces - both prev and cur have arguments + result = parser.extract_tool_calls_streaming( + '{"name": "fn", "arguments": {"k": "v', + '{"name": "fn", "arguments": {"k": "v"}}', + '"}}', + [1, 10, 20], + [1, 10, 20, 30], + [30], + self.dummy_request, + ) + # is_complete_json=True, delta ends with }, should strip trailing } + # After strip: '"' which is not empty, so returns DeltaMessage + self.assertIsNotNone(result) + self.assertIsInstance(result, DeltaMessage) + + def test_streaming_incremental_arguments_complete_empty_delta(self): + """Cover lines 304-305: complete JSON where delta becomes empty after strip""" + parser = self._new_parser() + parser.extract_tool_calls_streaming( + "", + '{"name": "fn"', + '{"name": "fn"', + [], + [1, 10], + [1, 10], + self.dummy_request, + ) + # First arguments with proper delta + parser.extract_tool_calls_streaming( + '{"name": "fn"', + '{"name": "fn", "arguments": {"k": "v"}', + '{"k": "v"}', + [1, 10], + [1, 10, 20], + [20], + self.dummy_request, + ) + # Send just the outer closing brace + # tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None + result = parser.extract_tool_calls_streaming( + '{"name": "fn", "arguments": {"k": "v"}', + '{"name": "fn", "arguments": {"k": "v"}}', + "}", + [1, 10, 20], + [1, 10, 20, 30], + [30], + self.dummy_request, + ) + # is_complete_json=True, delta="}" -> stripped to "" -> return None + self.assertIsNone(result) + + # --- Lines 316-319: prev_tool_call_arr update branches --- + + def test_streaming_prev_tool_call_arr_append(self): + """Cover lines 318-319: append to prev_tool_call_arr when index doesn't match""" + parser = self._new_parser() + parser.current_tool_id = 1 + parser.current_tool_name_sent = True + parser.streamed_args_for_tool = ["", ""] + parser.prev_tool_call_arr = [{"name": "fn1"}] + # current_tool_id (1) != len(prev_tool_call_arr) - 1 (0), so append + with patch( + "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", + return_value={"name": "fn2"}, + ): + parser.extract_tool_calls_streaming( + "", + '{"name": "fn2"}', + '{"name": "fn2"}', + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertEqual(len(parser.prev_tool_call_arr), 2) + + # --- Lines 323-325: top-level exception handler --- + + def test_streaming_general_exception(self): + """Cover lines 323-325: unexpected exception returns None""" + parser = self._new_parser() + parser.current_tool_name_sent = True + # Force an exception by corrupting internal state + parser.current_tool_id = 0 + parser.streamed_args_for_tool = [""] + parser.prev_tool_call_arr = None # will cause exception on access + result = parser.extract_tool_calls_streaming( + "", + '{"name": "fn"}', + '{"name": "fn"}', + [1], + [1, 10], + [10], + self.dummy_request, + ) + self.assertIsNone(result) + + # ==================== Full streaming simulation ==================== + + def test_streaming_full_flow(self): + """Integration test: simulate a full streaming tool call flow""" + parser = self._new_parser() + req = self.dummy_request + + # Step 1: text before tool call + r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req) + self.assertEqual(r.content, "thinking") + + # Step 2: tool_call start token + r = parser.extract_tool_calls_streaming("thinking", "thinking", "", [], [1], [1], req) + self.assertIsNone(r) + + # Step 3: function name appears + r = parser.extract_tool_calls_streaming( + "thinking", + 'thinking{"name": "search"', + '{"name": "search"', + [1], + [1, 10], + [10], + req, + ) + self.assertIsNotNone(r) + self.assertEqual(r.tool_calls[0].function.name, "search") + + # Step 4: arguments start - delta must appear in regex-extracted arguments portion + r = parser.extract_tool_calls_streaming( + 'thinking{"name": "search"', + 'thinking{"name": "search", "arguments": {"query": "test', + '{"query": "test', + [1, 10], + [1, 10, 20], + [20], + req, + ) + self.assertIsNotNone(r) + + # Step 5: more arguments + r = parser.extract_tool_calls_streaming( + 'thinking{"name": "search", "arguments": {"query": "test', + 'thinking{"name": "search", "arguments": {"query": "test data', + " data", + [1, 10, 20], + [1, 10, 20, 30], + [30], + req, + ) + self.assertIsNotNone(r) + self.assertEqual(r.tool_calls[0].function.arguments, " data") + + def test_streaming_multiple_tool_calls(self): + """Integration test: two tool calls in one response""" + parser = self._new_parser() + req = self.dummy_request + + # First tool call + parser.extract_tool_calls_streaming( + "", + '{"name": "fn1"', + '{"name": "fn1"', + [], + [1, 10], + [1, 10], + req, + ) + self.assertEqual(parser.current_tool_id, 0) + + # Close first tool + parser.extract_tool_calls_streaming( + '{"name": "fn1"', + '{"name": "fn1"}', + "}", + [1, 10], + [1, 10, 2], + [2], + req, + ) + + # Second tool call + r = parser.extract_tool_calls_streaming( + '{"name": "fn1"}', + '{"name": "fn1"}{"name": "fn2"', + '{"name": "fn2"', + [1, 10, 2], + [1, 10, 2, 1, 20], + [1, 20], + req, + ) + self.assertEqual(parser.current_tool_id, 1) + self.assertIsNotNone(r) + self.assertEqual(r.tool_calls[0].function.name, "fn2") + + +if __name__ == "__main__": + unittest.main()