diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py
index 79eb3058b15..4e6cc93cbaa 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py
@@ -246,10 +246,6 @@ def extract_tool_calls_streaming(
if self.valid is not None and not self.valid:
return DeltaMessage(content=delta_text)
- # Skip empty chunks
- if len(delta_text.strip()) == 0:
- return None
-
try:
delta = None
# Use buffer to accumulate delta_text content
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
index 3427a1814fa..ba162b2d516 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
@@ -18,7 +18,6 @@
import re
import uuid
from collections.abc import Sequence
-from typing import Union
import partial_json_parser
@@ -28,6 +27,8 @@ def random_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+from partial_json_parser.core.options import Allow
+
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
@@ -41,40 +42,44 @@ def random_tool_call_id() -> str:
ToolParser,
ToolParserManager,
)
-from fastdeploy.utils import data_processor_logger
+from fastdeploy.utils import data_processor_logger as logger
@ToolParserManager.register_module("ernie-x1")
class ErnieX1ToolParser(ToolParser):
"""
- Tool parser for Ernie model version 4.5.1.
This parser handles tool calls with newline formats.
"""
def __init__(self, tokenizer):
+ """
+ Ernie thinking model format:
+ abc\n\n\n\n\ndef\n\n
+ """
super().__init__(tokenizer)
-
+ self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
- self.current_tool_id: int = -1
- self.current_tool_name_sent: bool = False
- self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
- self.buffer: str = "" # buffer for accumulating unprocessed streaming content
- self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
- self.tool_call_start_token: str = "<tool_call>"
- self.tool_call_end_token: str = "</tool_call>"
+ self.current_tool_id = -1
+ self.streamed_args_for_tool: list[str] = []
+ self.think_end_token = "</think>"
+ self.response_start_token: str = "<response>"
+ self.response_end_token: str = "</response>"
+ self.tool_call_start_token = "<tool_call>"
+ self.tool_call_end_token = "</tool_call>"
- self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
- self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
- if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
- raise RuntimeError(
- "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
- )
+ self.tool_call_regex = re.compile(r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
- "The model tokenizer must be passed to the ToolCallParser constructor during construction."
+ "The model tokenizer must be passed to the ToolParser " "constructor during construction."
)
+ self.think_end_token_id = self.vocab.get(self.think_end_token)
+ self.response_start_token_id = self.vocab.get(self.response_start_token)
+ self.response_end_token_id = self.vocab.get(self.response_end_token)
+ self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
@@ -88,144 +93,28 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
"""
try:
+ tool_call_json_list = self.tool_call_regex.findall(model_output)
tool_calls = []
-
- # Check for invalid tags before tool calls
- if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
- data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
- return ExtractedToolCallInformation(tools_called=False, content=model_output)
-
- function_call_arr = []
- remaining_text = model_output
-
- while True:
- # Find the next
- tool_call_pos = remaining_text.find("<tool_call>")
- if tool_call_pos == -1:
- break
-
- # Extract content after
- tool_content_start = tool_call_pos + len("<tool_call>")
- tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
-
- tool_json = ""
- if tool_content_end == -1:
- # Processing unclosed tool_call block (truncated case)
- tool_json = remaining_text[tool_content_start:].strip()
- remaining_text = "" # No more content to process
- else:
- # Processing closed block
- tool_json = remaining_text[tool_content_start:tool_content_end].strip()
- remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
-
- if not tool_json:
- continue
-
- # Process tool_json
- tool_json = tool_json.strip()
- if not tool_json.startswith("{"):
- tool_json = "{" + tool_json
- if not tool_json.endswith("}"):
- tool_json = tool_json + "}"
-
- try:
- # Parsing strategy: First try standard json.loads
- try:
- tool_data = json.loads(tool_json)
-
- if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
- function_call_arr.append(
- {
- "name": tool_data["name"],
- "arguments": tool_data["arguments"],
- "_is_complete": True, # Mark as complete
- }
- )
- continue
- except json.JSONDecodeError:
- pass
-
- # Try partial_json_parser when standard parsing fails
- from partial_json_parser.core.options import Allow
-
- try:
- tool_data = {}
- flags = Allow.ALL & ~Allow.STR
-
- # Parse the name field
- name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
- if name_match:
- tool_data["name"] = name_match.group(1)
-
- # Parse the arguments field
- args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
- if args_match:
- try:
- tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
- except Exception as e:
- data_processor_logger.debug(f"Failed to parse tool arguments: {e}")
- tool_data["arguments"] = None
-
- if isinstance(tool_data, dict):
- function_call_arr.append(
- {
- "name": tool_data.get("name", ""),
- "arguments": tool_data.get("arguments", {}),
- "_is_partial": True, # Mark as partial
- }
- )
- except Exception as e:
- data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
- continue
- except Exception as e:
- data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
- continue
-
- if not function_call_arr:
- data_processor_logger.error("No valid tool calls found")
- return ExtractedToolCallInformation(tools_called=False, content=model_output)
-
- tool_calls = []
- all_complete = True # Initialize as all complete
-
- for tool_call in function_call_arr:
- # Set flags
- is_complete = tool_call.get("_is_complete", False)
- is_partial = tool_call.get("_is_partial", False)
-
- # If any tool call is incomplete or partial, mark all_complete as False
- if not is_complete or is_partial:
- all_complete = False
-
- # Process arguments
- tool_args = tool_call.get("arguments", {})
- if not isinstance(tool_args, dict):
- tool_args = {}
-
- try:
- args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
- except:
- args_str = "{}"
-
+ for tool_call_json in tool_call_json_list:
+ tool_call_dict = json.loads(tool_call_json)
+ args_str = json.dumps(tool_call_dict.get("arguments", {}), ensure_ascii=False)
tool_calls.append(
ToolCall(
type="function",
id=random_tool_call_id(),
function=FunctionCall(
- name=tool_call.get("name", ""),
+ name=tool_call_dict.get("name", ""),
arguments=args_str,
),
)
)
-
- # Only return tools_called=True if all tool calls are complete
return ExtractedToolCallInformation(
- tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
+ tools_called=True,
+ tool_calls=tool_calls,
)
-
- except Exception as e:
- data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
- return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+ except Exception:
+ logger.warning("Error in extracting tool call from response.")
+ return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
def extract_tool_calls_streaming(
self,
@@ -235,114 +124,202 @@ def extract_tool_calls_streaming(
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
- request: dict,
- ) -> Union[DeltaMessage, None]:
-
- if self.tool_call_start_token_id not in current_token_ids:
+ request: ChatCompletionRequest,
+ ) -> DeltaMessage | None:
+ if self.tool_call_start_token not in current_text:
+ logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
- # Skip empty chunks
- if len(delta_text.strip()) == 0:
- return None
try:
- delta = None
- # Use buffer to accumulate delta_text content
- self.buffer += delta_text
-
- # Process the buffer content
- if "" in delta_text:
- self.current_tool_id = (
- max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
+ prev_tool_start_count = previous_text.count(self.tool_call_start_token)
+ prev_tool_end_count = previous_text.count(self.tool_call_end_token)
+ cur_tool_start_count = current_text.count(self.tool_call_start_token)
+ cur_tool_end_count = current_text.count(self.tool_call_end_token)
+ tool_call_portion = None
+ text_portion = None
+
+ if (
+ cur_tool_start_count == cur_tool_end_count
+ and prev_tool_end_count == cur_tool_end_count
+ and self.tool_call_end_token not in delta_text
+ ):
+ logger.debug("Generating text content! skipping tool parsing.")
+ return DeltaMessage(content=delta_text)
+
+ if self.tool_call_end_token in delta_text:
+ logger.debug("tool_call_end_token in delta_text")
+ full_text = current_text + delta_text
+ tool_call_portion = (
+ full_text.split(self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0].rstrip()
)
+ delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
+ text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
+
+ flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+
+ if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count:
+ if len(delta_token_ids) > 1:
+ tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
+ else:
+ tool_call_portion = None
+ delta = None
+
+ text_portion = None
+
+ self.current_tool_id += 1
self.current_tool_name_sent = False
- if len(self.streamed_args_for_tool) <= self.current_tool_id:
- self.streamed_args_for_tool.append("")
- data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
-
- # 1. Try to parse the name field
- if not self.current_tool_name_sent and '"name"' in self.buffer:
- name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
- if name_match:
- name = name_match.group(1)
- if name:
- delta = DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_id,
- type="function",
- id=random_tool_call_id(),
- function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
- )
- ]
- )
- # Delete the processed name part from the buffer
- self.buffer = self.buffer[name_match.end() :]
- self.current_tool_name_sent = True
- return delta
- # 2. Processing arguments field
- if '"arguments"' in self.buffer:
- args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
- if args_match:
- args_content = args_match.group(1)
- try:
- # Check if arguments field is complete by bracket matching
- if "}}" in args_content:
- matched_pos = -1
- for i, ch in enumerate(delta_text):
- if ch == "{":
- self.bracket_counts["total_l"] += 1
- elif ch == "}":
- self.bracket_counts["total_r"] += 1
-
- if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
- matched_pos = i
- break
-
- if matched_pos >= 0:
- # Clean up bracket counts for next tool call
- truncate_text = delta_text[: matched_pos + 1]
- delta = DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_id,
- function=DeltaFunctionCall(arguments=truncate_text).model_dump(
- exclude_none=True
- ),
- )
- ]
- )
- self.buffer = self.buffer[args_match.end() :]
- return delta
- else:
- # No complete match yet
- return None
- else:
- # Return partial arguments
- for ch in delta_text:
- if ch == "{":
- self.bracket_counts["total_l"] += 1
- elif ch == "}":
- self.bracket_counts["total_r"] += 1
- delta = DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_id,
- function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
- )
- ]
- )
- return delta
- except Exception as e:
- data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+ self.streamed_args_for_tool.append("")
+ logger.debug("Starting on a new tool %s", self.current_tool_id)
+
+ elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count:
+ tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
+ text_portion = None
+
+ elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count >= prev_tool_end_count:
+ if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
+ logger.debug("attempting to close tool call, but no tool call")
+ return None
+ diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
+ if diff:
+ if '"}' not in delta_text:
return None
- if "" in self.buffer:
- end_pos = self.buffer.find("")
- self.buffer = self.buffer[end_pos + len("") :]
+ end_loc = delta_text.rindex('"}')
+ diff = delta_text[:end_loc] + '"}'
+ logger.debug(
+ "Finishing tool and found diff that had not " "been streamed yet: %s",
+ diff,
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += diff
+ return DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True),
+ )
+ ]
+ )
- self.streamed_args_for_tool.append("")
+ else:
+ text = delta_text.replace(self.tool_call_start_token, "")
+ text = text.replace(self.tool_call_end_token, "")
+ delta = DeltaMessage(tool_calls=[], content=text)
+ return delta
+
+ try:
+ current_tool_call = (
+ partial_json_parser.loads(tool_call_portion or "{}", flags) if tool_call_portion else None
+ )
+ logger.debug("Parsed tool call %s", current_tool_call)
+ except partial_json_parser.core.exceptions.MalformedJSON:
+ logger.debug("not enough tokens to parse into JSON yet")
+ return None
+ except json.decoder.JSONDecodeError:
+ logger.debug("unable to parse JSON")
+ return None
+
+ if not self.current_tool_name_sent:
+ if current_tool_call is None:
+ return None
+ function_name: str | None = current_tool_call.get("name")
+ if function_name:
+ self.current_tool_name_sent = True
+ return DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ type="function",
+ id=random_tool_call_id(),
+ function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ else:
+ return None
+
+ if tool_call_portion is None:
+ delta = DeltaMessage(content=delta_text) if text_portion is not None else None
+ return delta
+
+ if len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+
+ prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
+ cur_arguments = current_tool_call.get("arguments")
+
+ if not cur_arguments and not prev_arguments:
+ logger.debug("Skipping text %s - no arguments", delta_text)
+ delta = None
+
+ elif not cur_arguments and prev_arguments:
+ logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.")
+ delta = None
+
+ elif cur_arguments and not prev_arguments:
+ function_name = current_tool_call.get("name")
+ match = re.search(
+ r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
+ tool_call_portion.strip(),
+ re.DOTALL,
+ )
+ if match:
+ cur_arguments_json = match.group(1)
+ else:
+ cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
+
+ logger.debug("finding %s in %s", delta_text, cur_arguments_json)
+
+ if delta_text not in cur_arguments_json:
+ return None
+ args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(delta_text)
+
+ arguments_delta = cur_arguments_json[:args_delta_start_loc]
+ logger.debug("First tokens in arguments received: %s", arguments_delta)
+
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=arguments_delta).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
+
+ elif cur_arguments and prev_arguments:
+ try:
+ json.loads(tool_call_portion)
+ is_complete_json = True
+ except Exception:
+ is_complete_json = False
+
+ if (
+ isinstance(delta_text, str)
+ and len(delta_text.rstrip()) >= 1
+ and delta_text.rstrip()[-1] == "}"
+ and is_complete_json
+ ):
+ delta_text = delta_text.rstrip()[:-1]
+
+ logger.debug("got diff %s", delta_text)
+ if is_complete_json and delta_text.strip() == "":
+ return None
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += delta_text
+
+ if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
+ self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
+ else:
+ self.prev_tool_call_arr.append(current_tool_call)
return delta
- except Exception as e:
- data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
- return None
+ except Exception:
+ logger.warning("Error trying to handle streaming tool call.")
+ return None # do not stream a delta. skip this token ID.
diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py
index c8a24c5707d..5cde01c3933 100644
--- a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py
+++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py
@@ -15,6 +15,7 @@
"""
import unittest
+from unittest.mock import patch
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import (
@@ -22,115 +23,831 @@
)
-class DummyTokenizer:
- """Dummy tokenizer with vocab containing tool_call tokens"""
-
- def __init__(self):
- self.vocab = {"": 1, "": 2}
-
- def get_vocab(self):
- return self.vocab
-
-
class TestErnieX1ToolParser(unittest.TestCase):
def setUp(self):
+ class DummyTokenizer:
+ def __init__(self):
+ self.vocab = {
+ "<tool_call>": 1,
+ "</tool_call>": 2,
+ "</think>": 3,
+ "<response>": 4,
+ "</response>": 5,
+ }
+
+ def get_vocab(self):
+ return self.vocab
+
self.tokenizer = DummyTokenizer()
self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer)
self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
- # ---------------- Batch extraction tests ----------------
+ def _new_parser(self):
+ """Create a fresh parser to avoid state pollution between tests."""
+
+ class DummyTokenizer:
+ def __init__(self):
+ self.vocab = {
+ "<tool_call>": 1,
+ "</tool_call>": 2,
+ "</think>": 3,
+ "<response>": 4,
+ "</response>": 5,
+ }
+
+ def get_vocab(self):
+ return self.vocab
- def test_extract_tool_calls_complete(self):
- """Test normal extraction of complete tool_call JSON"""
- output = '<tool_call>{"name": "get_weather", "arguments": {"location": "Beijing"}}</tool_call>'
+ return ErnieX1ToolParser(tokenizer=DummyTokenizer())
+
+ # ==================== __init__ tests (lines 60-81) ====================
+
+ def test_init_sets_tokens_and_ids(self):
+ """Cover lines 60-81: verify all token attributes and vocab lookups"""
+ p = self.parser
+ self.assertFalse(p.current_tool_name_sent)
+ self.assertEqual(p.prev_tool_call_arr, [])
+ self.assertEqual(p.current_tool_id, -1)
+ self.assertEqual(p.streamed_args_for_tool, [])
+ self.assertEqual(p.think_end_token, "</think>")
+ self.assertEqual(p.response_start_token, "<response>")
+ self.assertEqual(p.response_end_token, "</response>")
+ self.assertEqual(p.tool_call_start_token, "<tool_call>")
+ self.assertEqual(p.tool_call_end_token, "</tool_call>")
+ self.assertIsNotNone(p.tool_call_regex)
+ self.assertEqual(p.think_end_token_id, 3)
+ self.assertEqual(p.response_start_token_id, 4)
+ self.assertEqual(p.response_end_token_id, 5)
+ self.assertEqual(p.tool_call_start_token_id, 1)
+ self.assertEqual(p.tool_call_end_token_id, 2)
+
+ def test_init_raises_without_tokenizer(self):
+ """Cover lines 72-75: ValueError when tokenizer is falsy"""
+ with self.assertRaises(ValueError):
+ ErnieX1ToolParser(tokenizer=None)
+
+ # ==================== extract_tool_calls tests (lines 96-117) ====================
+
+ def test_extract_tool_calls_single(self):
+ """Cover lines 96-114: single complete tool call"""
+ output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertTrue(result.tools_called)
+ self.assertEqual(len(result.tool_calls), 1)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+ self.assertIn("北京", result.tool_calls[0].function.arguments)
- def test_extract_tool_calls_no_toolcall(self):
- """Test when no tool_call tags are present"""
- output = "no tool call here"
+ def test_extract_tool_calls_multiple(self):
+ """Cover lines 98-100: multiple tool calls"""
+ output = (
+ '<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
+ '<tool_call>{"name": "get_time", "arguments": {"timezone": "UTC"}}</tool_call>'
+ )
result = self.parser.extract_tool_calls(output, self.dummy_request)
- self.assertFalse(result.tools_called)
+ self.assertTrue(result.tools_called)
+ self.assertEqual(len(result.tool_calls), 2)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+ self.assertEqual(result.tool_calls[1].function.name, "get_time")
- def test_extract_tool_calls_exception(self):
- """Completely broken JSON triggers the exception branch"""
- output = "not json at all{{{"
+ def test_extract_tool_calls_no_arguments(self):
+ """Cover line 100: tool call with no arguments defaults to {}"""
+ output = '<tool_call>{"name": "list_items"}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
- self.assertFalse(result.tools_called)
+ self.assertTrue(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.arguments, "{}")
- def test_extract_tool_calls_partial_json_parser_failure(self):
- """Test partial_json_parser failure path for arguments (L165-166).
- json.loads fails on malformed JSON, partial_json_parser.loads also fails on deeply broken args.
- Partial result has _is_partial=True so tools_called=False, but tool_calls is populated."""
- # Malformed JSON: valid name but arguments is a bare invalid token
- # that breaks both json.loads and partial_json_parser
- output = '<tool_call>{"name": "test", "arguments": @@@INVALID@@@}</tool_call>'
+ def test_extract_tool_calls_nested_arguments(self):
+ """Cover regex with nested braces in arguments"""
+ output = '<tool_call>{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
- # _is_partial=True → tools_called=False, but tool_calls list is populated
- self.assertFalse(result.tools_called)
- self.assertIsNotNone(result.tool_calls)
- self.assertEqual(result.tool_calls[0].function.name, "test")
- # arguments=None → converted to {} → serialized as "{}"
- self.assertEqual(result.tool_calls[0].function.arguments, "{}")
+ self.assertTrue(result.tools_called)
+ self.assertIn("$gt", result.tool_calls[0].function.arguments)
+
+ def test_extract_tool_calls_with_whitespace(self):
+ """Cover regex with whitespace around JSON"""
+ output = '<tool_call> \n{"name": "fn", "arguments": {}} \n</tool_call>'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertTrue(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "fn")
- def test_partial_json_parser_exception_triggers_debug_log(self):
- """Malformed JSON + partial_json_parser failure exercises L165-166 exactly."""
- # Unclosed string in arguments breaks both json.loads and partial_json_parser
- output = '{"name": "my_tool", "arguments": {"key": "unterminated}'
+ def test_extract_tool_calls_no_match(self):
+ """Cover lines 96, 111-114: no tool_call tags -> tools_called=True with empty list"""
+ output = "just plain text"
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertTrue(result.tools_called)
+ self.assertEqual(len(result.tool_calls), 0)
+
+ def test_extract_tool_calls_invalid_json(self):
+ """Cover lines 115-117: malformed JSON triggers exception branch"""
+ output = "{invalid json}"
result = self.parser.extract_tool_calls(output, self.dummy_request)
- # Partial parse → tools_called=False but tool_calls has entries
self.assertFalse(result.tools_called)
+ self.assertEqual(result.content, output)
+
+ def test_extract_tool_calls_exception(self):
+ """Cover lines 115-117: forced exception via mock"""
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads",
+ side_effect=Exception("boom"),
+ ):
+ output = '<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(result.content, output)
+
+ # ==================== extract_tool_calls_streaming tests ====================
+
+ # --- Line 129-131: no tool_call_start_token in current_text ---
+
+ def test_streaming_no_tool_call_token(self):
+ """Cover lines 129-131: no in current_text returns content delta"""
+ result = self.parser.extract_tool_calls_streaming("", "hello world", "world", [], [], [], self.dummy_request)
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.content, "world")
+ self.assertIsNone(result.tool_calls)
+
+ # --- Lines 141-147: balanced start/end counts, text generation after tool call ---
+
+ def test_streaming_balanced_counts_text_after_tool(self):
+ """Cover lines 134-147: start==end, prev_end==cur_end, end not in delta -> text content"""
+ prev = "{}"
+ cur = "{} some text"
+ delta = " some text"
+ result = self.parser.extract_tool_calls_streaming(prev, cur, delta, [1, 2], [1, 2], [], self.dummy_request)
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.content, delta)
+
+ # --- Lines 149-156: tool_call_end_token in delta_text ---
+
+ def test_streaming_end_token_in_delta(self):
+ """Cover lines 149-156: appears in delta"""
+ parser = self._new_parser()
+ # First, start a tool call
+ parser.extract_tool_calls_streaming(
+ "",
+ '<tool_call>{"name": "fn"',
+ '<tool_call>{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ # Now stream arguments
+ parser.extract_tool_calls_streaming(
+ '<tool_call>{"name": "fn"',
+ '<tool_call>{"name": "fn", "arguments": {"k": "v',
+ ', "arguments": {"k": "v',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+ # Close with end token in delta
+ result = parser.extract_tool_calls_streaming(
+ '<tool_call>{"name": "fn", "arguments": {"k": "v',
+ '<tool_call>{"name": "fn", "arguments": {"k": "v"}}</tool_call>',
+ '"}}</tool_call>',
+ [1, 10, 20],
+ [1, 10, 20, 2],
+ [2],
+ self.dummy_request,
+ )
+ # Should handle end token
+ self.assertTrue(result is None or isinstance(result, DeltaMessage))
+
+ # --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) ---
+
+ def test_streaming_new_tool_call_single_token(self):
+ """Cover lines 160-172 (len(delta_token_ids)==1): new tool start with single token"""
+ parser = self._new_parser()
+ result = parser.extract_tool_calls_streaming(
+ "",
+ "",
+ "",
+ [],
+ [1],
+ [1],
+ self.dummy_request,
+ )
+ # tool_call_portion is None, current_tool_call is None, name not sent -> None
+ self.assertIsNone(result)
+ self.assertEqual(parser.current_tool_id, 0)
+ self.assertEqual(len(parser.streamed_args_for_tool), 1)
+
+ def test_streaming_new_tool_call_multi_tokens(self):
+ """Cover lines 160-162 (len(delta_token_ids)>1): new tool start with content"""
+ parser = self._new_parser()
+ result = parser.extract_tool_calls_streaming(
+ "",
+ '<tool_call>{"name": "fn"',
+ '<tool_call>{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ self.assertIsNotNone(result)
+ self.assertEqual(result.tool_calls[0].function.name, "fn")
+ self.assertEqual(parser.current_tool_id, 0)
+
+ # --- Lines 174-176: continuing inside existing tool (cur_start > cur_end, same start count) ---
+
+ def test_streaming_continue_tool_call_no_name_yet(self):
+ """Cover lines 174-176, 220-222: partial JSON without name yet"""
+ parser = self._new_parser()
+ # Start tool call
+ parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
+ # Continue with partial content, no name parseable yet
+ result = parser.extract_tool_calls_streaming(
+ "",
+ '{"na',
+ '{"na',
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_continue_tool_call_with_name(self):
+ """Cover lines 174-176, 223-235: name becomes available"""
+ parser = self._new_parser()
+ # Start tool call
+ parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
+ # Name appears
+ result = parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "get_weather"',
+ '{"name": "get_weather"',
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ self.assertIsNotNone(result)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+ self.assertTrue(parser.current_tool_name_sent)
+
+ # --- Lines 236-237: name not sent and function_name is None ---
+
+ def test_streaming_no_function_name(self):
+ """Cover lines 236-237: parsed JSON has no 'name' field"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
+ # Send JSON without name field
+ result = parser.extract_tool_calls_streaming(
+ "",
+ '{"arguments": {"k": "v"}}',
+ '{"arguments": {"k": "v"}}',
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ # --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) ---
+
+ def test_streaming_close_no_prev_tool_call(self):
+ """Cover lines 178-181: close branch with empty prev_tool_call_arr"""
+ parser = self._new_parser()
+ parser.prev_tool_call_arr = []
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ result = parser.extract_tool_calls_streaming(
+ '<tool_call>{"name":"fn","arguments":{"k":"v"}}',
+ '<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
+ "</tool_call>",
+ [1, 10],
+ [1, 10, 2],
+ [2],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_close_with_remaining_diff(self):
+ """Cover lines 182-200: close with arguments diff that hasn't been streamed"""
+ parser = self._new_parser()
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
+ result = parser.extract_tool_calls_streaming(
+ '<tool_call>{"name":"fn","arguments":{"k":"v',
+ '<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
+ '"}}</tool_call>',
+ [1, 10],
+ [1, 10, 2],
+ [2],
+ self.dummy_request,
+ )
+ self.assertIsNotNone(result)
self.assertIsNotNone(result.tool_calls)
- self.assertEqual(result.tool_calls[0].function.name, "my_tool")
- # ---------------- Streaming extraction tests ----------------
+ def test_streaming_close_with_diff_no_end_marker(self):
+ """Cover lines 184-185: close with arguments but no '"}' in delta_text"""
+ parser = self._new_parser()
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
+ # Simulate end token in delta but without '"}' pattern
+ # We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta
+ # so that we enter the elif at 178
+ result = parser.extract_tool_calls_streaming(
+ '{"name":"fn","arguments":{"k":"v"}}',
+ '{"name":"fn","arguments":{"k":"v"}} text',
+ " text",
+ [1, 10, 2],
+ [1, 10, 2, 30],
+ [30],
+ self.dummy_request,
+ )
+ # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147)
+ self.assertIsInstance(result, DeltaMessage)
+
+ def test_streaming_close_no_arguments(self):
+ """Cover lines 182-183: close branch where prev arguments is None/empty"""
+ parser = self._new_parser()
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = [{"name": "fn"}] # no arguments key
+ result = parser.extract_tool_calls_streaming(
+ '{"name":"fn"}',
+ '{"name":"fn"}',
+ "}",
+ [1, 10],
+ [1, 10, 2],
+ [2],
+ self.dummy_request,
+ )
+ # diff is None (no arguments), so falls through to partial_json_parser
+ self.assertTrue(result is None or isinstance(result, DeltaMessage))
+
+ # --- Lines 202-206: else branch (cur_start < cur_end, edge case) ---
+
+ def test_streaming_else_branch(self):
+ """Cover lines 202-206: fall-through else branch"""
+ parser = self._new_parser()
+ parser.current_tool_name_sent = True
+ # Construct a scenario where cur_start < cur_end (more end tags than start)
+ prev = ""
+ cur = ""
+ delta = ""
+ result = parser.extract_tool_calls_streaming(prev, cur, delta, [1], [1, 2, 2], [2], self.dummy_request)
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.tool_calls, [])
+
+ # --- Lines 208-218: partial_json_parser errors ---
+
+ def test_streaming_malformed_json(self):
+ """Cover lines 213-215: MalformedJSON from partial parser"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request)
+ # Feed badly formed content
+ result = parser.extract_tool_calls_streaming(
+ "",
+ "{{{",
+ "{{{",
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_json_decode_error(self):
+ """Cover lines 216-218: JSONDecodeError from partial parser"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request)
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
+ side_effect=ValueError("bad json"),
+ ):
+ result = parser.extract_tool_calls_streaming(
+ "",
+ "bad",
+ "bad",
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ # --- Lines 239-241: tool_call_portion is None after name sent ---
+
+ def test_streaming_tool_portion_none_with_text(self):
+ """Cover lines 239-241: tool_call_portion is None, text_portion is not None"""
+ parser = self._new_parser()
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = [{}]
+        # Force tool_call_portion = None while text_portion is not None.
+        # This state arises when the end token appears in the delta (which
+        # sets text_portion) and a new tool-start token follows in the same
+        # chunk, overriding tool_call_portion to None.
+        # The token ids below simulate exactly that: the delta carries an
+        # end token (2) immediately followed by a start token (1), so the
+        # parser takes the continuing branch with tool_call_portion still
+        # None from the end-token path.
+ result = parser.extract_tool_calls_streaming(
+ '{"name":"fn"}',
+ '{"name":"fn"}',
+ "",
+ [1, 10],
+ [1, 10, 2, 1],
+ [2, 1],
+ self.dummy_request,
+ )
+ self.assertTrue(result is None or isinstance(result, DeltaMessage))
+
+ # --- Lines 243-244: append to prev_tool_call_arr ---
+
+ def test_streaming_first_arguments_with_regex_match(self):
+ """Cover lines 243-244, 257-286: first arguments appear, regex matches"""
+ parser = self._new_parser()
+ # Start tool call and send name
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "get_weather"',
+ '{"name": "get_weather"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ # Now stream arguments (first time)
+ # Key must be complete (closing quote) so partial_json_parser returns truthy arguments.
+ # delta must be a substring of the regex-extracted arguments portion (after "arguments":).
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "get_weather"',
+ '{"name": "get_weather", "arguments": {"location": "bei',
+ '"bei',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+ self.assertIsNotNone(result)
+ self.assertIsNotNone(result.tool_calls)
+
+ def test_streaming_first_arguments_no_regex_match(self):
+ """Cover lines 266-267: regex doesn't match, fallback to json.dumps"""
+ parser = self._new_parser()
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = [{}]
+ # Use tool_call_portion where key order differs so regex won't match
+ # (regex expects {"name":... at the start, but here "extra" comes first)
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
+ return_value={"name": "fn", "extra": True, "arguments": {"k": "v"}},
+ ):
+ result = parser.extract_tool_calls_streaming(
+ "",
+ '{"extra": true, "name": "fn", "arguments": {"k": "v"}}',
+ '"v"}',
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ # regex fails on {"extra":... format, falls back to json.dumps
+ # delta '"v"}' is in json.dumps({"k": "v"}) = '{"k": "v"}'
+ self.assertIsNotNone(result)
+ self.assertIsNotNone(result.tool_calls)
+
+ def test_streaming_first_arguments_delta_not_in_json(self):
+ """Cover lines 271-272: delta_text not found in cur_arguments_json"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn"',
+ '{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ # Delta text that doesn't appear in the arguments JSON
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "fn"',
+ '{"name": "fn", "arguments": {"k": "v"}}',
+ "ZZZZZ",
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ # --- Lines 249-251: no cur_arguments and no prev_arguments ---
- def test_streaming_no_toolcall(self):
- """Streaming extraction returns normal DeltaMessage when no toolcall tag"""
- result = self.parser.extract_tool_calls_streaming(
- "", "abc", "abc", [], [], [], self.dummy_request.model_dump()
+ def test_streaming_no_arguments_at_all(self):
+ """Cover lines 249-251: both cur and prev arguments are empty/None"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn"',
+ '{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
)
+ # Continue with name only, no arguments
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "fn"',
+ '{"name": "fn"}',
+ "}",
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+        # prev_arguments=None, cur_arguments=None -> delta=None;
+        # prev_tool_call_arr is then updated and delta (None) is returned
+ self.assertIsNone(result)
+
+ # --- Lines 253-255: cur_arguments reset (impossible branch) ---
+
+ def test_streaming_arguments_reset_mid_call(self):
+ """Cover lines 253-255: prev has arguments but cur doesn't (impossible case)"""
+ parser = self._new_parser()
+ parser.current_tool_id = 0
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
+ # Feed content where cur has no arguments but prev does
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
+ return_value={"name": "fn"},
+ ):
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "fn", "arguments": {"k": "v"',
+ '{"name": "fn", "arguments": {"k": "v"}',
+ '"}',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+ self.assertIsNone(result)
+
+ # --- Lines 288-314: cur_arguments and prev_arguments both present ---
+
+ def test_streaming_incremental_arguments_incomplete(self):
+ """Cover lines 288-314: both prev and cur have arguments, JSON incomplete"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn"',
+ '{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ # First arguments - delta must appear in regex-extracted arguments portion
+ parser.extract_tool_calls_streaming(
+ '{"name": "fn"',
+ '{"name": "fn", "arguments": {"k": "v',
+ '{"k": "v',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+ # More argument tokens (both prev and cur have arguments now)
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "fn", "arguments": {"k": "v',
+ '{"name": "fn", "arguments": {"k": "val',
+ "al",
+ [1, 10, 20],
+ [1, 10, 20, 30],
+ [30],
+ self.dummy_request,
+ )
+ self.assertIsNotNone(result)
+ self.assertEqual(result.tool_calls[0].function.arguments, "al")
+
+ def test_streaming_incremental_arguments_complete_json(self):
+ """Cover lines 289-305: complete JSON with trailing }"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn"',
+ '{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ # First arguments - delta must appear in regex-extracted arguments portion
+ parser.extract_tool_calls_streaming(
+ '{"name": "fn"',
+ '{"name": "fn", "arguments": {"k": "v',
+ '{"k": "v',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
+ )
+ # Complete with closing braces - both prev and cur have arguments
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "fn", "arguments": {"k": "v',
+ '{"name": "fn", "arguments": {"k": "v"}}',
+ '"}}',
+ [1, 10, 20],
+ [1, 10, 20, 30],
+ [30],
+ self.dummy_request,
+ )
+ # is_complete_json=True, delta ends with }, should strip trailing }
+ # After strip: '"' which is not empty, so returns DeltaMessage
+ self.assertIsNotNone(result)
self.assertIsInstance(result, DeltaMessage)
- self.assertEqual(result.content, "abc")
- def test_streaming_skip_empty_chunk(self):
- """Streaming extraction skips empty chunks"""
- result = self.parser.extract_tool_calls_streaming(
- "", "", " ", [], [1], [1], self.dummy_request.model_dump()
+ def test_streaming_incremental_arguments_complete_empty_delta(self):
+ """Cover lines 304-305: complete JSON where delta becomes empty after strip"""
+ parser = self._new_parser()
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn"',
+ '{"name": "fn"',
+ [],
+ [1, 10],
+ [1, 10],
+ self.dummy_request,
+ )
+ # First arguments with proper delta
+ parser.extract_tool_calls_streaming(
+ '{"name": "fn"',
+ '{"name": "fn", "arguments": {"k": "v"}',
+ '{"k": "v"}',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ self.dummy_request,
)
+ # Send just the outer closing brace
+ # tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None
+ result = parser.extract_tool_calls_streaming(
+ '{"name": "fn", "arguments": {"k": "v"}',
+ '{"name": "fn", "arguments": {"k": "v"}}',
+ "}",
+ [1, 10, 20],
+ [1, 10, 20, 30],
+ [30],
+ self.dummy_request,
+ )
+ # is_complete_json=True, delta="}" -> stripped to "" -> return None
self.assertIsNone(result)
- def test_streaming_new_toolcall_and_name(self):
- """Streaming extraction detects new toolcall and extracts name"""
- delta = self.parser.extract_tool_calls_streaming(
- "", "", '{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump()
+ # --- Lines 316-319: prev_tool_call_arr update branches ---
+
+ def test_streaming_prev_tool_call_arr_append(self):
+ """Cover lines 318-319: append to prev_tool_call_arr when index doesn't match"""
+ parser = self._new_parser()
+ parser.current_tool_id = 1
+ parser.current_tool_name_sent = True
+ parser.streamed_args_for_tool = ["", ""]
+ parser.prev_tool_call_arr = [{"name": "fn1"}]
+ # current_tool_id (1) != len(prev_tool_call_arr) - 1 (0), so append
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
+ return_value={"name": "fn2"},
+ ):
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn2"}',
+ '{"name": "fn2"}',
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
+ )
+ self.assertEqual(len(parser.prev_tool_call_arr), 2)
+
+ # --- Lines 323-325: top-level exception handler ---
+
+ def test_streaming_general_exception(self):
+ """Cover lines 323-325: unexpected exception returns None"""
+ parser = self._new_parser()
+ parser.current_tool_name_sent = True
+ # Force an exception by corrupting internal state
+ parser.current_tool_id = 0
+ parser.streamed_args_for_tool = [""]
+ parser.prev_tool_call_arr = None # will cause exception on access
+ result = parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn"}',
+ '{"name": "fn"}',
+ [1],
+ [1, 10],
+ [10],
+ self.dummy_request,
)
- self.assertIsNotNone(delta)
- self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
+ self.assertIsNone(result)
+
+ # ==================== Full streaming simulation ====================
- def test_streaming_partial_arguments(self):
- """Streaming extraction yields partial arguments deltas"""
- text = '"arguments": {"location":'
- delta = self.parser.extract_tool_calls_streaming(
- "", "" + text, text, [], [1], [1], self.dummy_request.model_dump()
+ def test_streaming_full_flow(self):
+ """Integration test: simulate a full streaming tool call flow"""
+ parser = self._new_parser()
+ req = self.dummy_request
+
+ # Step 1: text before tool call
+ r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req)
+ self.assertEqual(r.content, "thinking")
+
+ # Step 2: tool_call start token
+ r = parser.extract_tool_calls_streaming("thinking", "thinking", "", [], [1], [1], req)
+ self.assertIsNone(r)
+
+ # Step 3: function name appears
+ r = parser.extract_tool_calls_streaming(
+ "thinking",
+ 'thinking{"name": "search"',
+ '{"name": "search"',
+ [1],
+ [1, 10],
+ [10],
+ req,
)
- self.assertIsInstance(delta, DeltaMessage)
- self.assertIn("arguments", delta.tool_calls[0].function.arguments)
+ self.assertIsNotNone(r)
+ self.assertEqual(r.tool_calls[0].function.name, "search")
- def test_streaming_complete_arguments_and_end(self):
- """Streaming extraction completes arguments with brackets matched and closes tool_call"""
- text = '"arguments": {"location": "Beijing"}}'
- delta = self.parser.extract_tool_calls_streaming(
- "", "" + text, text, [], [1], [1], self.dummy_request.model_dump()
+ # Step 4: arguments start - delta must appear in regex-extracted arguments portion
+ r = parser.extract_tool_calls_streaming(
+ 'thinking{"name": "search"',
+ 'thinking{"name": "search", "arguments": {"query": "test',
+ '{"query": "test',
+ [1, 10],
+ [1, 10, 20],
+ [20],
+ req,
)
- self.assertIsInstance(delta, DeltaMessage)
- # Also simulate closing tag
- end_delta = self.parser.extract_tool_calls_streaming(
- "", "", "", [], [2], [2], self.dummy_request.model_dump()
+ self.assertIsNotNone(r)
+
+ # Step 5: more arguments
+ r = parser.extract_tool_calls_streaming(
+ 'thinking{"name": "search", "arguments": {"query": "test',
+ 'thinking{"name": "search", "arguments": {"query": "test data',
+ " data",
+ [1, 10, 20],
+ [1, 10, 20, 30],
+ [30],
+ req,
+ )
+ self.assertIsNotNone(r)
+ self.assertEqual(r.tool_calls[0].function.arguments, " data")
+
+ def test_streaming_multiple_tool_calls(self):
+ """Integration test: two tool calls in one response"""
+ parser = self._new_parser()
+ req = self.dummy_request
+
+ # First tool call
+ parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "fn1"',
+ '{"name": "fn1"',
+ [],
+ [1, 10],
+ [1, 10],
+ req,
+ )
+ self.assertEqual(parser.current_tool_id, 0)
+
+ # Close first tool
+ parser.extract_tool_calls_streaming(
+ '{"name": "fn1"',
+ '{"name": "fn1"}',
+ "}",
+ [1, 10],
+ [1, 10, 2],
+ [2],
+ req,
+ )
+
+ # Second tool call
+ r = parser.extract_tool_calls_streaming(
+ '{"name": "fn1"}',
+ '{"name": "fn1"}{"name": "fn2"',
+ '{"name": "fn2"',
+ [1, 10, 2],
+ [1, 10, 2, 1, 20],
+ [1, 20],
+ req,
)
- self.assertIsNotNone(end_delta)
- self.assertEqual(end_delta.content, "")
+ self.assertEqual(parser.current_tool_id, 1)
+ self.assertIsNotNone(r)
+ self.assertEqual(r.tool_calls[0].function.name, "fn2")
if __name__ == "__main__":