diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index d3f27122dba..494d2248380 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -1909,7 +1909,11 @@ def _send_error_response(self, request_id, error_msg, error_code: int = 500, wor
def _decode_token(self, token_ids, req_id, is_end):
delta_text = ""
if envs.FD_ENABLE_RETURN_TEXT:
- delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
+ delta_text, previous_token_ids, _ = self.data_processor.ids2tokens(token_ids, req_id)
+ # Reconstruct the post-extend cumulative list from the pre-delta
+ # snapshot + this call's input — ``ids2tokens`` only returns the
+ # snapshot to keep its return values aliasing-free.
+ cum_tokens = previous_token_ids + list(token_ids)
if delta_text != "":
prefix_offset = self.data_processor.decode_status[req_id][0]
read_offset = self.data_processor.decode_status[req_id][1]
diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py
index ffaaf0f4aa5..2cfef290201 100644
--- a/fastdeploy/entrypoints/openai/response_processors.py
+++ b/fastdeploy/entrypoints/openai/response_processors.py
@@ -72,7 +72,9 @@ def accumulate_token_ids(self, request_output):
else:
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
- async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output, request):
+ async def process_response_chat(
+ self, request_outputs, stream, include_stop_str_in_output, request, prompt_tokens=None
+ ):
"""
Process a list of responses into a generator that yields each processed response as it's generated.
Args:
@@ -101,6 +103,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
audio_tokens=all_audio_tokens,
tts=tts,
request=request,
+ prompt_tokens=prompt_tokens,
)
else:
response = self.data_processor.process_response_dict(
@@ -110,6 +113,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
audio_tokens=all_audio_tokens,
tts=tts,
request=request,
+ prompt_tokens=prompt_tokens,
)
yield response
elif decode_type == 2: # audio
@@ -128,6 +132,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
else:
response = self.data_processor.process_response_dict(
@@ -135,6 +140,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
yield response
elif stream:
@@ -168,6 +174,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
else:
self.data_processor.process_response_dict(
@@ -175,6 +182,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
text = {"type": "text", "text": request_output["outputs"]["text"]}
request_output["outputs"]["multipart"] = [text]
@@ -197,6 +205,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=False,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
else:
self.data_processor.process_response_dict(
@@ -204,6 +213,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
multipart.append(text)
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index d6429521f05..25b77220d27 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -317,6 +317,7 @@ async def chat_completion_stream_generator(
stream=True,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
async for res in generator:
@@ -650,6 +651,7 @@ async def chat_completion_full_generator(
stream=False,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
+ prompt_tokens=prompt_tokens,
)
async for data in generator:
idx = get_choice_index(data["request_id"])
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
index 906483f445a..461f702cd1b 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
@@ -34,6 +34,15 @@ class ToolParser:
derived classes.
"""
+ # Subclasses should override these with the literal tool-call sentinel
+ # tokens they recognize (e.g. ``""`` / ``""``).
+ # Used by :meth:`detect_tool_prefix` to support forced tool-call prompt
+ # prefix injection (named-tool ``tool_choice`` or
+ # ``chat_template_kwargs.options.tool_choice.mode == "force"``). Empty
+ # defaults make the detection a no-op for parsers that have not opted in.
+ tool_call_start_token: str = ""
+ tool_call_end_token: str = ""
+
def __init__(self, tokenizer):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
@@ -43,6 +52,16 @@ def __init__(self, tokenizer):
self.model_tokenizer = tokenizer
+ # Per-request tool-prefix state populated by the serving layer when
+ # the chat template injects a forced tool-call prefix into the prompt.
+ self._tool_prefix: str = ""
+ self._tool_prefix_token_ids: list[int] = []
+ # Set after the prefix is computed once for this request.
+ self._tool_prefix_computed: bool = False
+ # Set after the prefix has been spliced into the streaming delta
+ # (only the first chunk needs it).
+ self._tool_prefix_injected_to_delta: bool = False
+
@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
@@ -55,6 +74,36 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques
"""
return request
+ def detect_tool_prefix(self, prompt: str) -> str:
+ """Detect the tool-call prefix injected at the tail of the rendered
+ prompt by a forced ``tool_choice``.
+
+ Finds the **last** :attr:`tool_call_start_token` in ``prompt`` that is
+ not closed by a later :attr:`tool_call_end_token` and reaches the
+ prompt end (modulo trailing whitespace). Returns ``""`` otherwise.
+ Subclasses with non-paired tag formats may override.
+ """
+ start = self.tool_call_start_token
+ if not start or not prompt:
+ return ""
+
+ last_start = prompt.rfind(start)
+ if last_start == -1:
+ return ""
+
+ end = self.tool_call_end_token
+ if end and prompt.find(end, last_start + len(start)) != -1:
+ # The last start token is closed — this is a historical, completed
+ # tool-call (e.g. from a previous assistant turn), not an injected
+ # forced prefix.
+ return ""
+
+ # By construction, ``prompt[last_start:]`` reaches the end of the
+ # prompt. We treat the whole tail as the injected prefix. Subclasses
+ # whose chat templates place additional content after the prefix can
+ # override this method to apply stricter validation.
+ return prompt[last_start:]
+
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Static method that should be implemented for extracting tool calls from
diff --git a/fastdeploy/input/base_processor.py b/fastdeploy/input/base_processor.py
index c65e0c42cf4..21c30fb19d7 100644
--- a/fastdeploy/input/base_processor.py
+++ b/fastdeploy/input/base_processor.py
@@ -138,6 +138,28 @@ def text2ids(self, text, max_model_len=None, **kwargs):
)
return tokens["input_ids"][0]
+ def _text_to_token_ids(self, text: str) -> list:
+ """Encode ``text`` to a ``list[int]``, shared by :meth:`messages2ids`
+ and :meth:`_prepare_tool_prefix`.
+
+ ``ernie4_5`` tokenizer hangs on long inputs via ``.encode()``, so it
+ goes through ``tokenize`` + ``convert_tokens_to_ids``. Other tokenizers
+ use ``.encode()`` and the result is normalized to a plain list.
+ """
+ if self.tokenizer_type == "ernie4_5":
+ # NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode()
+ return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+ token_ids = self.tokenizer.encode(text, add_special_tokens=False)
+ if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
+ token_ids = token_ids["input_ids"]
+ if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
+ token_ids = token_ids[0]
+ if hasattr(token_ids, "tolist"):
+ token_ids = token_ids.tolist()
+ if not isinstance(token_ids, list):
+ token_ids = list(token_ids)
+ return token_ids
+
def messages2ids(self, request, **kwargs):
"""Convert a chat-template request into a token-ID list.
@@ -159,19 +181,7 @@ def messages2ids(self, request, **kwargs):
)
request["prompt_tokens"] = spliced_message
req_id = request.get("request_id", None) if isinstance(request, dict) else None
- if self.tokenizer_type == "ernie4_5":
- # NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode()
- token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(spliced_message))
- else:
- token_ids = self.tokenizer.encode(spliced_message, add_special_tokens=False)
- if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
- token_ids = token_ids["input_ids"]
- if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
- token_ids = token_ids[0]
- if hasattr(token_ids, "tolist"):
- token_ids = token_ids.tolist()
- if not isinstance(token_ids, list):
- token_ids = list(token_ids)
+ token_ids = self._text_to_token_ids(spliced_message)
log_request(
level=1,
message="req_id:{req_id}, token_ids: {token_ids}",
@@ -204,9 +214,16 @@ def ids2tokens(self, token_id, task_id):
Returns:
(delta_text, previous_token_ids, previous_texts)
- Both the HF and the PaddleFormers/ERNIE tokeniser paths return the
- same tuple shape. The HF path sets ``previous_token_ids`` to ``[]``
- since it does not expose per-token ids during batch-decode.
+ ``previous_token_ids`` and ``previous_texts`` are **snapshots of the
+ accumulated state BEFORE this call's tokens were appended** —
+ symmetric pre-delta views of what the caller had decoded so far.
+ Both are owned by the caller (no aliasing of internal state).
+
+ Callers that need the post-extend cumulative list should reconstruct
+ it locally via ``previous_token_ids + token_id``.
+
+ The HF path returns ``[]`` for ``previous_token_ids`` since it does
+ not expose per-token ids during batch-decode.
"""
if envs.FD_USE_HF_TOKENIZER:
if task_id not in self.decode_status:
@@ -225,7 +242,9 @@ def ids2tokens(self, token_id, task_id):
status[2] = decode_str[0]
else:
new_str = ""
- # Return consistent three-tuple; previous_token_ids not available.
+ # NOTE: HF path historically returns the post-delta full string
+ # here, inconsistent with the non-HF branch (which returns the
+ # pre-delta snapshot). Preserved as-is to avoid behavior change.
return new_str, [], status[2]
else:
if task_id not in self.decode_status:
@@ -233,12 +252,15 @@ def ids2tokens(self, token_id, task_id):
self.decode_status[task_id] = [0, 0, [], ""]
status = self.decode_status[task_id]
previous_texts = status[3]
+ # Snapshot BEFORE extend so the returned list is owned by the
+ # caller and symmetric with ``previous_texts``.
+ previous_token_ids = list(status[2])
status[2].extend(token_id)
decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(status[2], status[0], status[1])
status[0] = prefix_offset
status[1] = read_offset
status[3] += decode_str
- return decode_str, status[2], previous_texts
+ return decode_str, previous_token_ids, previous_texts
# ------------------------------------------------------------------
# Response processing
@@ -266,6 +288,37 @@ def process_response_dict(self, response_dict, **kwargs):
else:
return self.process_response_dict_normal(response_dict, **kwargs)
+ def _prepare_tool_prefix(self, tool_parser, prompt_tokens):
+ """Detect and cache on ``tool_parser`` the tool-call prefix that the
+ chat template injected at the tail of ``prompt_tokens`` (the rendered
+ prompt string from the serving layer). Computed once per parser
+ instance via the parser's :meth:`ToolParser.detect_tool_prefix`.
+ """
+ if tool_parser._tool_prefix_computed:
+ return
+ tool_parser._tool_prefix_computed = True
+ tool_parser._tool_prefix = ""
+ tool_parser._tool_prefix_token_ids = []
+ if not prompt_tokens or not isinstance(prompt_tokens, str):
+ return
+ try:
+ prefix = tool_parser.detect_tool_prefix(prompt_tokens) or ""
+ except Exception:
+ data_processor_logger.exception("detect_tool_prefix failed; falling back to empty prefix")
+ return
+ tool_parser._tool_prefix = prefix
+ if not prefix:
+ return
+ # Encode the prefix into token ids so the streaming path can also
+ # splice ``previous/current/delta_token_ids`` — some parsers gate on
+ # ``tool_call_start_token_id in current_token_ids`` rather than on
+ # text (e.g. ``Ernie45VLThinkingToolParser``).
+ try:
+ tool_parser._tool_prefix_token_ids = self._text_to_token_ids(prefix)
+ except Exception:
+ data_processor_logger.exception("encode tool prefix to token ids failed; token-id splice disabled")
+ tool_parser._tool_prefix_token_ids = []
+
def process_response_dict_normal(self, response_dict, **kwargs):
"""Accumulate tokens and build the full completion text (non-streaming)."""
token_ids = response_dict["outputs"]["token_ids"]
@@ -300,7 +353,11 @@ def process_response_dict_normal(self, response_dict, **kwargs):
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
- tool_call_info = tool_parser.extract_tool_calls(full_text, request)
+ parser_input = full_text
+ self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens"))
+ if tool_parser._tool_prefix:
+ parser_input = tool_parser._tool_prefix + full_text
+ tool_call_info = tool_parser.extract_tool_calls(parser_input, request)
if tool_call_info.tools_called:
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
@@ -354,13 +411,38 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
if req_id not in self.tool_parser_dict:
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parser_dict[req_id]
+ stream_previous = previous_texts
+ stream_current = previous_texts + delta_text
+ stream_delta = delta_text
+ stream_previous_token_ids = previous_token_ids
+ stream_current_token_ids = previous_token_ids + token_ids
+ stream_delta_token_ids = token_ids
+ self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens"))
+ prefix = tool_parser._tool_prefix
+ prefix_ids = tool_parser._tool_prefix_token_ids
+ # Splice the injected prefix back into both text and token-id
+ # streaming args so parsers that gate on either form (e.g.
+ # ``Ernie45VLThinkingToolParser`` checks
+ # ``tool_call_start_token_id in current_token_ids``) work
+ # unchanged. ``delta_*`` only spliced on the first call.
+ if prefix:
+ stream_previous = prefix + stream_previous
+ stream_current = prefix + stream_current
+ if prefix_ids:
+ stream_previous_token_ids = list(prefix_ids) + list(stream_previous_token_ids)
+ stream_current_token_ids = list(prefix_ids) + list(stream_current_token_ids)
+ if not tool_parser._tool_prefix_injected_to_delta:
+ stream_delta = prefix + stream_delta
+ if prefix_ids:
+ stream_delta_token_ids = list(prefix_ids) + list(stream_delta_token_ids)
+ tool_parser._tool_prefix_injected_to_delta = True
tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
- previous_texts,
- previous_texts + delta_text,
- delta_text,
- previous_token_ids,
- previous_token_ids + token_ids,
- token_ids,
+ stream_previous,
+ stream_current,
+ stream_delta,
+ stream_previous_token_ids,
+ stream_current_token_ids,
+ stream_delta_token_ids,
request,
)
if tool_call_delta_message:
diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py
index a3487133bcb..3f05797a37a 100644
--- a/tests/engine/test_common_engine.py
+++ b/tests/engine/test_common_engine.py
@@ -752,7 +752,9 @@ def __init__(self):
self.decode_status = {"rid": (0, 2)}
def ids2tokens(self, token_ids, req_id):
- return "hi", [101, 102], None
+ # previous_token_ids snapshot is empty (first call); engine
+ # reconstructs cum = previous + input = [101, 102].
+ return "hi", [], None
eng.data_processor = DummyProcessor()
@@ -782,7 +784,8 @@ def __init__(self):
self.decode_status = {"rid": (0, 1)}
def ids2tokens(self, token_ids, req_id):
- return "", [7], None
+ # previous snapshot is empty; cum becomes [7].
+ return "", [], None
eng.data_processor = DummyProcessor()
@@ -1975,7 +1978,8 @@ def __init__(self):
self.decode_status = {"rid": (0, 2)}
def ids2tokens(self, token_ids, req_id):
- return "hi", [1, 2], None
+ # previous snapshot empty; cum = [] + [1, 2] = [1, 2].
+ return "hi", [], None
eng.data_processor = DummyProcessor()
@@ -3453,7 +3457,8 @@ def __init__(self):
self.decode_status = {"tok-req": (1, 3)}
def ids2tokens(self, token_ids, req_id):
- return "hello", [10, 20, 30], None
+ # previous snapshot empty; cum = [] + [10, 20, 30].
+ return "hello", [], None
eng.data_processor = DummyProcessor()
diff --git a/tests/entrypoints/openai/test_finish_reason.py b/tests/entrypoints/openai/test_finish_reason.py
index 067b80ca0e5..74ce54e21cd 100644
--- a/tests/entrypoints/openai/test_finish_reason.py
+++ b/tests/entrypoints/openai/test_finish_reason.py
@@ -262,7 +262,9 @@ async def test_chat_full_max_tokens(self, mock_data_logger, mock_processor_class
mock_processor_instance = Mock()
mock_processor_instance.enable_multimodal_content.return_value = True
- async def mock_process_response_chat_async(response, stream, include_stop_str_in_output, request=None):
+ async def mock_process_response_chat_async(
+ response, stream, include_stop_str_in_output, request=None, prompt_tokens=None
+ ):
yield response
mock_processor_instance.process_response_chat = mock_process_response_chat_async
@@ -445,7 +447,9 @@ async def test_chat_stream_max_tokens(self, mock_api_logger, mock_processor_clas
mock_processor_instance = Mock()
mock_processor_instance.enable_multimodal_content.return_value = False
- async def mock_process_response_chat_async(response, stream, include_stop_str_in_output, request=None):
+ async def mock_process_response_chat_async(
+ response, stream, include_stop_str_in_output, request=None, prompt_tokens=None
+ ):
if isinstance(response, list):
for res in response:
yield res
diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py
index 63db437cc5d..c2efcdd03a0 100644
--- a/tests/entrypoints/openai/test_max_streaming_tokens.py
+++ b/tests/entrypoints/openai/test_max_streaming_tokens.py
@@ -222,7 +222,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class
mock_processor_instance = Mock()
- async def mock_process_response_chat_single(response, stream, include_stop_str_in_output, request=None):
+ async def mock_process_response_chat_single(
+ response, stream, include_stop_str_in_output, request=None, prompt_tokens=None
+ ):
yield response
mock_processor_instance.process_response_chat = mock_process_response_chat_single
@@ -639,7 +641,9 @@ async def test_chat_stream_usage_fields(self, mock_response_processor, api_serve
mock_processor_instance = Mock()
- async def mock_process_response_chat(response, stream, include_stop_str_in_output, request=None):
+ async def mock_process_response_chat(
+ response, stream, include_stop_str_in_output, request=None, prompt_tokens=None
+ ):
delta_msg_mock = Mock()
delta_msg_mock.content = response["outputs"]["text"]
if response["outputs"]["text"] == "a":
diff --git a/tests/entrypoints/openai/tool_parsers/test_abstract_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_abstract_tool_parser.py
new file mode 100644
index 00000000000..d8fd3acec0f
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_abstract_tool_parser.py
@@ -0,0 +1,99 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
+
+
+class _DummyTokenizer:
+ def get_vocab(self):
+ return {}
+
+
+class _PairedTagParser(ToolParser):
+ """A concrete parser declaring paired sentinel tokens for testing."""
+
+ tool_call_start_token = ""
+ tool_call_end_token = ""
+
+
+class _NoSentinelParser(ToolParser):
+ """A parser that did not opt in to prefix detection."""
+
+
+class TestDetectToolPrefix(unittest.TestCase):
+ def setUp(self):
+ self.tokenizer = _DummyTokenizer()
+ self.parser = _PairedTagParser(self.tokenizer)
+
+ def test_initial_state(self):
+ self.assertEqual(self.parser._tool_prefix, "")
+ self.assertFalse(self.parser._tool_prefix_computed)
+ self.assertFalse(self.parser._tool_prefix_injected_to_delta)
+
+ def test_empty_prompt_returns_empty(self):
+ self.assertEqual(self.parser.detect_tool_prefix(""), "")
+
+ def test_no_start_token_returns_empty(self):
+ self.assertEqual(
+ self.parser.detect_tool_prefix("user: hello\nassistant: hi"),
+ "",
+ )
+
+ def test_parser_without_sentinel_returns_empty(self):
+ parser = _NoSentinelParser(self.tokenizer)
+ self.assertEqual(
+ parser.detect_tool_prefix("anything here"),
+ "",
+ )
+
+ def test_trailing_start_token_only(self):
+ prompt = "user: q\n"
+ self.assertEqual(self.parser.detect_tool_prefix(prompt), "")
+
+ def test_trailing_start_with_invoke_prefix(self):
+ prompt = "history\n{...}\nuser: next"
+ self.assertEqual(self.parser.detect_tool_prefix(prompt), "")
+
+ def test_history_closed_plus_new_injected_prefix(self):
+ prompt = "{a:1}\n{a:1}\n" "{b:2}\n" "assistant: done"
+ self.assertEqual(self.parser.detect_tool_prefix(prompt), "")
+
+ def test_trailing_whitespace_after_start(self):
+ prompt = "history\n "
+ self.assertEqual(
+ self.parser.detect_tool_prefix(prompt),
+ " ",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py
index ebb4c9ff127..940fe51ec46 100644
--- a/tests/input/test_text_processor.py
+++ b/tests/input/test_text_processor.py
@@ -130,6 +130,8 @@ def _create_dummy_modules():
info=lambda *args, **kwargs: None,
warning=lambda *args, **kwargs: None,
debug=lambda *args, **kwargs: None,
+ exception=lambda *args, **kwargs: None,
+ error=lambda *args, **kwargs: None,
)
CHOICE_SEPARATOR = "::n::"
@@ -330,6 +332,13 @@ def create_dummy_tool_parser(tokenizer, content="tool-text"):
class DummyToolParser:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
+ self._tool_prefix = ""
+ self._tool_prefix_token_ids = []
+ self._tool_prefix_computed = False
+ self._tool_prefix_injected_to_delta = False
+
+ def detect_tool_prefix(self, prompt):
+ return ""
def extract_tool_calls(self, full_text, response_dict):
# 模拟工具调用解析,返回固定的工具调用数据用于测试
@@ -570,7 +579,9 @@ def test_process_response_with_reasoning_and_tools(self):
"outputs": {"token_ids": [1, processor.tokenizer.eos_token_id]},
}
- processed = processor.process_response_dict(response, stream=False)
+ processed = processor.process_response_dict(
+ response, stream=False, request=SimpleNamespace(chat_template_kwargs=None)
+ )
self.assertEqual(processed["outputs"]["reasoning_content"], "think")
self.assertEqual(processed["outputs"]["tool_calls"], ["tool"])
@@ -597,7 +608,9 @@ def test_process_response_streaming_with_reasoning_and_tools(self):
"outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]},
}
- result = processor.process_response_dict_streaming(response, enable_thinking=True)
+ result = processor.process_response_dict_streaming(
+ response, enable_thinking=True, request=SimpleNamespace(chat_template_kwargs=None)
+ )
self.assertEqual(result["outputs"]["completion_tokens"], "7")
self.assertEqual(result["outputs"]["text"], "tool-text")
self.assertEqual(result["outputs"]["reasoning_content"], "because")
@@ -615,7 +628,9 @@ def test_process_response_dict_normal_with_reasoning(self):
"outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]},
}
- result = processor.process_response_dict_normal(response, enable_thinking=True)
+ result = processor.process_response_dict_normal(
+ response, enable_thinking=True, request=SimpleNamespace(chat_template_kwargs=None)
+ )
self.assertEqual(result["outputs"]["completion_tokens"], "7")
self.assertEqual(result["outputs"]["reasoning_content"], "because")
self.assertEqual(result["outputs"]["reasoning_token_num"], 1)
@@ -753,5 +768,213 @@ def custom_convert(tokens):
self.assertEqual(processor.update_bad_words(["combo", "oversize"], []), [])
+class _RecordingToolParser:
+ """Minimal tool parser that records inputs and exposes the prefix-state
+ fields the serving layer reads/writes."""
+
+ def __init__(self, tokenizer, tool_prefix="", detect_raises=False):
+ self.tokenizer = tokenizer
+ self._configured_prefix = tool_prefix
+ self._detect_raises = detect_raises
+ self._tool_prefix = ""
+ self._tool_prefix_computed = False
+ self._tool_prefix_injected_to_delta = False
+ self.detect_calls = []
+ self.extract_calls = []
+ self.streaming_calls = []
+
+ def detect_tool_prefix(self, prompt):
+ self.detect_calls.append(prompt)
+ if self._detect_raises:
+ raise RuntimeError("boom")
+ return self._configured_prefix if prompt and prompt.endswith(self._configured_prefix) else ""
+
+ def extract_tool_calls(self, model_output, request):
+ self.extract_calls.append(model_output)
+ return SimpleNamespace(tools_called=True, tool_calls=["tc"])
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ delta_token_ids,
+ request,
+ ):
+ self.streaming_calls.append(
+ {
+ "previous_text": previous_text,
+ "current_text": current_text,
+ "delta_text": delta_text,
+ "previous_token_ids": list(previous_token_ids),
+ "current_token_ids": list(current_token_ids),
+ "delta_token_ids": list(delta_token_ids),
+ }
+ )
+ tool_calls = [
+ DeltaToolCall(
+ index=0,
+ type="function",
+ id="x",
+ function=DeltaFunctionCall(name="t").model_dump(exclude_none=True),
+ )
+ ]
+ return DeltaMessage(tool_calls=tool_calls, content="c")
+
+
+class ToolPrefixCompensationTest(unittest.TestCase):
+ """Tests for the forced-tool-call prefix compensation logic in
+ ``BaseTextProcessor``. Splicing is driven entirely by whether the
+ rendered prompt ends with an unclosed tool-call start token, not by
+ request parameter introspection."""
+
+ def setUp(self):
+ module, cleanup = _import_text_processor()
+ self.text_processor_module = module
+ self.addCleanup(cleanup)
+ self.processor = module.TextProcessor("stub-model")
+
+ def _make_parser_factory(self, parser):
+ return lambda tokenizer: parser
+
+ def test_prepare_tool_prefix_idempotent(self):
+ parser = _RecordingToolParser(self.processor.tokenizer)
+ prompt = "history\n"
+
+ self.processor._prepare_tool_prefix(parser, prompt)
+ self.assertTrue(parser._tool_prefix_computed)
+ self.assertEqual(parser._tool_prefix, "")
+ self.assertEqual(len(parser.detect_calls), 1)
+
+ # Second call must not invoke detect again.
+ self.processor._prepare_tool_prefix(parser, prompt)
+ self.assertEqual(len(parser.detect_calls), 1)
+
+ def test_prepare_tool_prefix_no_prompt(self):
+ parser = _RecordingToolParser(self.processor.tokenizer)
+ self.processor._prepare_tool_prefix(parser, None)
+ self.assertTrue(parser._tool_prefix_computed)
+ self.assertEqual(parser._tool_prefix, "")
+ self.assertEqual(parser.detect_calls, [])
+
+ parser2 = _RecordingToolParser(self.processor.tokenizer)
+ self.processor._prepare_tool_prefix(parser2, "")
+ self.assertEqual(parser2._tool_prefix, "")
+ self.assertEqual(parser2.detect_calls, [])
+
+ def test_prepare_tool_prefix_handles_exception(self):
+ parser = _RecordingToolParser(self.processor.tokenizer, detect_raises=True)
+ self.processor._prepare_tool_prefix(parser, "history\n")
+ self.assertTrue(parser._tool_prefix_computed)
+ self.assertEqual(parser._tool_prefix, "")
+
+ def test_normal_path_splices_prefix_when_prompt_has_prefix(self):
+ """Prompt ending with an unclosed tool-call start triggers splicing,
+ regardless of how the user requested it."""
+ processor = self.processor
+ parser = _RecordingToolParser(processor.tokenizer, tool_prefix="")
+ processor.tool_parser_obj = self._make_parser_factory(parser)
+
+ response = {
+ "request_id": "req-normal",
+ "finished": True,
+ "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]},
+ }
+ processor.process_response_dict_normal(
+ response,
+ request=SimpleNamespace(chat_template_kwargs=None),
+ prompt_tokens="user msg\n",
+ )
+ self.assertEqual(len(parser.extract_calls), 1)
+ # Model output is "7" after decoding token 7; prefix must be prepended.
+ self.assertTrue(parser.extract_calls[0].startswith(""))
+ self.assertEqual(response["outputs"]["tool_calls"], ["tc"])
+
+ def test_normal_path_no_splice_when_prompt_lacks_prefix(self):
+ """No prefix in prompt tail => detect returns "" => no splice."""
+ processor = self.processor
+ parser = _RecordingToolParser(processor.tokenizer, tool_prefix="")
+ processor.tool_parser_obj = self._make_parser_factory(parser)
+
+ response = {
+ "request_id": "req-auto",
+ "finished": True,
+ "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]},
+ }
+ processor.process_response_dict_normal(
+ response,
+ request=SimpleNamespace(chat_template_kwargs=None),
+ prompt_tokens="user msg without sentinel",
+ )
+ # detect_tool_prefix is called, but returns "" => no prefix prepended.
+ self.assertEqual(len(parser.detect_calls), 1)
+ self.assertFalse(parser.extract_calls[0].startswith(""))
+
+ def test_streaming_path_splices_prefix_only_on_first_delta(self):
+ processor = self.processor
+ parser = _RecordingToolParser(processor.tokenizer, tool_prefix="")
+ processor.tool_parser_obj = self._make_parser_factory(parser)
+ request = SimpleNamespace(chat_template_kwargs=None)
+ prompt_tokens = "user msg\n"
+
+ # First chunk
+ first = {
+ "finished": False,
+ "request_id": "stream-req",
+ "outputs": {"token_ids": [7]},
+ }
+ processor.process_response_dict_streaming(first, request=request, prompt_tokens=prompt_tokens)
+ first_call = parser.streaming_calls[0]
+ # delta_text decodes to "7"; previous="" current="7"
+ self.assertEqual(first_call["previous_text"], "")
+ self.assertEqual(first_call["current_text"], "7")
+ self.assertEqual(first_call["delta_text"], "7")
+ # token_ids must be spliced too — DummyTokenizer.encode("") -> [11].
+ prefix_ids = [11]
+ self.assertEqual(first_call["previous_token_ids"], prefix_ids)
+ self.assertEqual(first_call["current_token_ids"], prefix_ids + [7])
+ self.assertEqual(first_call["delta_token_ids"], prefix_ids + [7])
+ self.assertTrue(parser._tool_prefix_injected_to_delta)
+ self.assertEqual(parser._tool_prefix_token_ids, prefix_ids)
+
+ # Second chunk: delta must NOT be re-spliced, but previous/current are.
+ second = {
+ "finished": True,
+ "request_id": "stream-req",
+ "outputs": {"token_ids": [8, processor.tokenizer.eos_token_id]},
+ }
+ processor.process_response_dict_streaming(second, request=request, prompt_tokens=prompt_tokens)
+ second_call = parser.streaming_calls[1]
+ self.assertEqual(second_call["previous_text"], "7")
+ self.assertEqual(second_call["current_text"], "78")
+ self.assertEqual(second_call["delta_text"], "8") # no extra prefix splice
+ # ``is_end=True`` causes the eos token to be stripped before ids2tokens,
+ # so token_ids fed to the parser is just [8].
+ self.assertEqual(second_call["previous_token_ids"], prefix_ids + [7])
+ self.assertEqual(second_call["current_token_ids"], prefix_ids + [7, 8])
+ self.assertEqual(second_call["delta_token_ids"], [8])
+ # detect should only run once across the whole stream.
+ self.assertEqual(len(parser.detect_calls), 1)
+
+ def test_streaming_path_no_splice_when_no_prefix_detected(self):
+ processor = self.processor
+ # Empty configured prefix => detect returns "" even when prompt looks
+ # like a forced rendering.
+ parser = _RecordingToolParser(processor.tokenizer, tool_prefix="")
+ processor.tool_parser_obj = self._make_parser_factory(parser)
+ request = SimpleNamespace(chat_template_kwargs=None)
+
+ first = {
+ "finished": False,
+ "request_id": "stream-noprefix",
+ "outputs": {"token_ids": [7]},
+ }
+ processor.process_response_dict_streaming(first, request=request, prompt_tokens="no sentinel")
+ self.assertEqual(parser.streaming_calls[0]["delta_text"], "7")
+ self.assertFalse(parser._tool_prefix_injected_to_delta)
+
+
if __name__ == "__main__":
unittest.main()