diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index d3f27122dba..494d2248380 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1909,7 +1909,11 @@ def _send_error_response(self, request_id, error_msg, error_code: int = 500, wor def _decode_token(self, token_ids, req_id, is_end): delta_text = "" if envs.FD_ENABLE_RETURN_TEXT: - delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id) + delta_text, previous_token_ids, _ = self.data_processor.ids2tokens(token_ids, req_id) + # Reconstruct the post-extend cumulative list from the pre-delta + # snapshot + this call's input — ``ids2tokens`` only returns the + # snapshot to keep its return values aliasing-free. + cum_tokens = previous_token_ids + list(token_ids) if delta_text != "": prefix_offset = self.data_processor.decode_status[req_id][0] read_offset = self.data_processor.decode_status[req_id][1] diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py index ffaaf0f4aa5..2cfef290201 100644 --- a/fastdeploy/entrypoints/openai/response_processors.py +++ b/fastdeploy/entrypoints/openai/response_processors.py @@ -72,7 +72,9 @@ def accumulate_token_ids(self, request_output): else: self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output}) - async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output, request): + async def process_response_chat( + self, request_outputs, stream, include_stop_str_in_output, request, prompt_tokens=None + ): """ Process a list of responses into a generator that yields each processed response as it's generated. Args: @@ -101,6 +103,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ audio_tokens=all_audio_tokens, tts=tts, request=request, + prompt_tokens=prompt_tokens, ) else: response = self.data_processor.process_response_dict( @@ -110,6 +113,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ audio_tokens=all_audio_tokens, tts=tts, request=request, + prompt_tokens=prompt_tokens, ) yield response elif decode_type == 2: # audio @@ -128,6 +132,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ stream=stream, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) else: response = self.data_processor.process_response_dict( @@ -135,6 +140,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ stream=stream, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) yield response elif stream: @@ -168,6 +174,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ stream=stream, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) else: self.data_processor.process_response_dict( @@ -175,6 +182,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ stream=stream, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) text = {"type": "text", "text": request_output["outputs"]["text"]} request_output["outputs"]["multipart"] = [text] @@ -197,6 +205,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ stream=False, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) else: self.data_processor.process_response_dict( @@ -204,6 +213,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ stream=stream, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) text = {"type": "text", "text": part["request_output"]["outputs"]["text"]} multipart.append(text) diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index d6429521f05..25b77220d27 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -317,6 +317,7 @@ async def chat_completion_stream_generator( stream=True, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) async for res in generator: @@ -650,6 +651,7 @@ async def chat_completion_full_generator( stream=False, include_stop_str_in_output=include_stop_str_in_output, request=request, + prompt_tokens=prompt_tokens, ) async for data in generator: idx = get_choice_index(data["request_id"]) diff --git a/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 906483f445a..461f702cd1b 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -34,6 +34,15 @@ class ToolParser: derived classes. """ + # Subclasses should override these with the literal tool-call sentinel + # tokens they recognize (e.g. ``""`` / ``""``). + # Used by :meth:`detect_tool_prefix` to support forced tool-call prompt + # prefix injection (named-tool ``tool_choice`` or + # ``chat_template_kwargs.options.tool_choice.mode == "force"``). Empty + # defaults make the detection a no-op for parsers that have not opted in. + tool_call_start_token: str = "" + tool_call_end_token: str = "" + def __init__(self, tokenizer): self.prev_tool_call_arr: list[dict] = [] # the index of the tool call that is currently being parsed @@ -43,6 +52,16 @@ def __init__(self, tokenizer): self.model_tokenizer = tokenizer + # Per-request tool-prefix state populated by the serving layer when + # the chat template injects a forced tool-call prefix into the prompt. + self._tool_prefix: str = "" + self._tool_prefix_token_ids: list[int] = [] + # Set after the prefix is computed once for this request. + self._tool_prefix_computed: bool = False + # Set after the prefix has been spliced into the streaming delta + # (only the first chunk needs it). + self._tool_prefix_injected_to_delta: bool = False + @cached_property def vocab(self) -> dict[str, int]: # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab @@ -55,6 +74,36 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques """ return request + def detect_tool_prefix(self, prompt: str) -> str: + """Detect the tool-call prefix injected at the tail of the rendered + prompt by a forced ``tool_choice``. + + Finds the **last** :attr:`tool_call_start_token` in ``prompt`` that is + not closed by a later :attr:`tool_call_end_token` and reaches the + prompt end (modulo trailing whitespace). Returns ``""`` otherwise. + Subclasses with non-paired tag formats may override. + """ + start = self.tool_call_start_token + if not start or not prompt: + return "" + + last_start = prompt.rfind(start) + if last_start == -1: + return "" + + end = self.tool_call_end_token + if end and prompt.find(end, last_start + len(start)) != -1: + # The last start token is closed — this is a historical, completed + # tool-call (e.g. from a previous assistant turn), not an injected + # forced prefix. + return "" + + # By construction, ``prompt[last_start:]`` reaches the end of the + # prompt. We treat the whole tail as the injected prefix. Subclasses + # whose chat templates place additional content after the prefix can + # override this method to apply stricter validation. + return prompt[last_start:] + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from diff --git a/fastdeploy/input/base_processor.py b/fastdeploy/input/base_processor.py index c65e0c42cf4..21c30fb19d7 100644 --- a/fastdeploy/input/base_processor.py +++ b/fastdeploy/input/base_processor.py @@ -138,6 +138,28 @@ def text2ids(self, text, max_model_len=None, **kwargs): ) return tokens["input_ids"][0] + def _text_to_token_ids(self, text: str) -> list: + """Encode ``text`` to a ``list[int]``, shared by :meth:`messages2ids` + and :meth:`_prepare_tool_prefix`. + + ``ernie4_5`` tokenizer hangs on long inputs via ``.encode()``, so it + goes through ``tokenize`` + ``convert_tokens_to_ids``. Other tokenizers + use ``.encode()`` and the result is normalized to a plain list. + """ + if self.tokenizer_type == "ernie4_5": + # NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode() + return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + token_ids = self.tokenizer.encode(text, add_special_tokens=False) + if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids): + token_ids = token_ids["input_ids"] + if hasattr(token_ids, "ndim") and token_ids.ndim > 1: + token_ids = token_ids[0] + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + if not isinstance(token_ids, list): + token_ids = list(token_ids) + return token_ids + def messages2ids(self, request, **kwargs): """Convert a chat-template request into a token-ID list. @@ -159,19 +181,7 @@ def messages2ids(self, request, **kwargs): ) request["prompt_tokens"] = spliced_message req_id = request.get("request_id", None) if isinstance(request, dict) else None - if self.tokenizer_type == "ernie4_5": - # NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode() - token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(spliced_message)) - else: - token_ids = self.tokenizer.encode(spliced_message, add_special_tokens=False) - if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids): - token_ids = token_ids["input_ids"] - if hasattr(token_ids, "ndim") and token_ids.ndim > 1: - token_ids = token_ids[0] - if hasattr(token_ids, "tolist"): - token_ids = token_ids.tolist() - if not isinstance(token_ids, list): - token_ids = list(token_ids) + token_ids = self._text_to_token_ids(spliced_message) log_request( level=1, message="req_id:{req_id}, token_ids: {token_ids}", @@ -204,9 +214,16 @@ def ids2tokens(self, token_id, task_id): Returns: (delta_text, previous_token_ids, previous_texts) - Both the HF and the PaddleFormers/ERNIE tokeniser paths return the - same tuple shape. The HF path sets ``previous_token_ids`` to ``[]`` - since it does not expose per-token ids during batch-decode. + ``previous_token_ids`` and ``previous_texts`` are **snapshots of the + accumulated state BEFORE this call's tokens were appended** — + symmetric pre-delta views of what the caller had decoded so far. + Both are owned by the caller (no aliasing of internal state). + + Callers that need the post-extend cumulative list should reconstruct + it locally via ``previous_token_ids + token_id``. + + The HF path returns ``[]`` for ``previous_token_ids`` since it does + not expose per-token ids during batch-decode. """ if envs.FD_USE_HF_TOKENIZER: if task_id not in self.decode_status: @@ -225,7 +242,9 @@ def ids2tokens(self, token_id, task_id): status[2] = decode_str[0] else: new_str = "" - # Return consistent three-tuple; previous_token_ids not available. + # NOTE: HF path historically returns the post-delta full string + # here, inconsistent with the non-HF branch (which returns the + # pre-delta snapshot). Preserved as-is to avoid behavior change. return new_str, [], status[2] else: if task_id not in self.decode_status: @@ -233,12 +252,15 @@ def ids2tokens(self, token_id, task_id): self.decode_status[task_id] = [0, 0, [], ""] status = self.decode_status[task_id] previous_texts = status[3] + # Snapshot BEFORE extend so the returned list is owned by the + # caller and symmetric with ``previous_texts``. + previous_token_ids = list(status[2]) status[2].extend(token_id) decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(status[2], status[0], status[1]) status[0] = prefix_offset status[1] = read_offset status[3] += decode_str - return decode_str, status[2], previous_texts + return decode_str, previous_token_ids, previous_texts # ------------------------------------------------------------------ # Response processing @@ -266,6 +288,37 @@ def process_response_dict(self, response_dict, **kwargs): else: return self.process_response_dict_normal(response_dict, **kwargs) + def _prepare_tool_prefix(self, tool_parser, prompt_tokens): + """Detect and cache on ``tool_parser`` the tool-call prefix that the + chat template injected at the tail of ``prompt_tokens`` (the rendered + prompt string from the serving layer). Computed once per parser + instance via the parser's :meth:`ToolParser.detect_tool_prefix`. + """ + if tool_parser._tool_prefix_computed: + return + tool_parser._tool_prefix_computed = True + tool_parser._tool_prefix = "" + tool_parser._tool_prefix_token_ids = [] + if not prompt_tokens or not isinstance(prompt_tokens, str): + return + try: + prefix = tool_parser.detect_tool_prefix(prompt_tokens) or "" + except Exception: + data_processor_logger.exception("detect_tool_prefix failed; falling back to empty prefix") + return + tool_parser._tool_prefix = prefix + if not prefix: + return + # Encode the prefix into token ids so the streaming path can also + # splice ``previous/current/delta_token_ids`` — some parsers gate on + # ``tool_call_start_token_id in current_token_ids`` rather than on + # text (e.g. ``Ernie45VLThinkingToolParser``). + try: + tool_parser._tool_prefix_token_ids = self._text_to_token_ids(prefix) + except Exception: + data_processor_logger.exception("encode tool prefix to token ids failed; token-id splice disabled") + tool_parser._tool_prefix_token_ids = [] + def process_response_dict_normal(self, response_dict, **kwargs): """Accumulate tokens and build the full completion text (non-streaming).""" token_ids = response_dict["outputs"]["token_ids"] @@ -300,7 +353,11 @@ def process_response_dict_normal(self, response_dict, **kwargs): if self.tool_parser_obj: tool_parser = self.tool_parser_obj(self.tokenizer) - tool_call_info = tool_parser.extract_tool_calls(full_text, request) + parser_input = full_text + self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens")) + if tool_parser._tool_prefix: + parser_input = tool_parser._tool_prefix + full_text + tool_call_info = tool_parser.extract_tool_calls(parser_input, request) if tool_call_info.tools_called: response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls @@ -354,13 +411,38 @@ def process_response_dict_streaming(self, response_dict, **kwargs): if req_id not in self.tool_parser_dict: self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer) tool_parser = self.tool_parser_dict[req_id] + stream_previous = previous_texts + stream_current = previous_texts + delta_text + stream_delta = delta_text + stream_previous_token_ids = previous_token_ids + stream_current_token_ids = previous_token_ids + token_ids + stream_delta_token_ids = token_ids + self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens")) + prefix = tool_parser._tool_prefix + prefix_ids = tool_parser._tool_prefix_token_ids + # Splice the injected prefix back into both text and token-id + # streaming args so parsers that gate on either form (e.g. + # ``Ernie45VLThinkingToolParser`` checks + # ``tool_call_start_token_id in current_token_ids``) work + # unchanged. ``delta_*`` only spliced on the first call. + if prefix: + stream_previous = prefix + stream_previous + stream_current = prefix + stream_current + if prefix_ids: + stream_previous_token_ids = list(prefix_ids) + list(stream_previous_token_ids) + stream_current_token_ids = list(prefix_ids) + list(stream_current_token_ids) + if not tool_parser._tool_prefix_injected_to_delta: + stream_delta = prefix + stream_delta + if prefix_ids: + stream_delta_token_ids = list(prefix_ids) + list(stream_delta_token_ids) + tool_parser._tool_prefix_injected_to_delta = True tool_call_delta_message = tool_parser.extract_tool_calls_streaming( - previous_texts, - previous_texts + delta_text, - delta_text, - previous_token_ids, - previous_token_ids + token_ids, - token_ids, + stream_previous, + stream_current, + stream_delta, + stream_previous_token_ids, + stream_current_token_ids, + stream_delta_token_ids, request, ) if tool_call_delta_message: diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index a3487133bcb..3f05797a37a 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -752,7 +752,9 @@ def __init__(self): self.decode_status = {"rid": (0, 2)} def ids2tokens(self, token_ids, req_id): - return "hi", [101, 102], None + # previous_token_ids snapshot is empty (first call); engine + # reconstructs cum = previous + input = [101, 102]. + return "hi", [], None eng.data_processor = DummyProcessor() @@ -782,7 +784,8 @@ def __init__(self): self.decode_status = {"rid": (0, 1)} def ids2tokens(self, token_ids, req_id): - return "", [7], None + # previous snapshot is empty; cum becomes [7]. + return "", [], None eng.data_processor = DummyProcessor() @@ -1975,7 +1978,8 @@ def __init__(self): self.decode_status = {"rid": (0, 2)} def ids2tokens(self, token_ids, req_id): - return "hi", [1, 2], None + # previous snapshot empty; cum = [] + [1, 2] = [1, 2]. + return "hi", [], None eng.data_processor = DummyProcessor() @@ -3453,7 +3457,8 @@ def __init__(self): self.decode_status = {"tok-req": (1, 3)} def ids2tokens(self, token_ids, req_id): - return "hello", [10, 20, 30], None + # previous snapshot empty; cum = [] + [10, 20, 30]. + return "hello", [], None eng.data_processor = DummyProcessor() diff --git a/tests/entrypoints/openai/test_finish_reason.py b/tests/entrypoints/openai/test_finish_reason.py index 067b80ca0e5..74ce54e21cd 100644 --- a/tests/entrypoints/openai/test_finish_reason.py +++ b/tests/entrypoints/openai/test_finish_reason.py @@ -262,7 +262,9 @@ async def test_chat_full_max_tokens(self, mock_data_logger, mock_processor_class mock_processor_instance = Mock() mock_processor_instance.enable_multimodal_content.return_value = True - async def mock_process_response_chat_async(response, stream, include_stop_str_in_output, request=None): + async def mock_process_response_chat_async( + response, stream, include_stop_str_in_output, request=None, prompt_tokens=None + ): yield response mock_processor_instance.process_response_chat = mock_process_response_chat_async @@ -445,7 +447,9 @@ async def test_chat_stream_max_tokens(self, mock_api_logger, mock_processor_clas mock_processor_instance = Mock() mock_processor_instance.enable_multimodal_content.return_value = False - async def mock_process_response_chat_async(response, stream, include_stop_str_in_output, request=None): + async def mock_process_response_chat_async( + response, stream, include_stop_str_in_output, request=None, prompt_tokens=None + ): if isinstance(response, list): for res in response: yield res diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index 63db437cc5d..c2efcdd03a0 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -222,7 +222,9 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class mock_processor_instance = Mock() - async def mock_process_response_chat_single(response, stream, include_stop_str_in_output, request=None): + async def mock_process_response_chat_single( + response, stream, include_stop_str_in_output, request=None, prompt_tokens=None + ): yield response mock_processor_instance.process_response_chat = mock_process_response_chat_single @@ -639,7 +641,9 @@ async def test_chat_stream_usage_fields(self, mock_response_processor, api_serve mock_processor_instance = Mock() - async def mock_process_response_chat(response, stream, include_stop_str_in_output, request=None): + async def mock_process_response_chat( + response, stream, include_stop_str_in_output, request=None, prompt_tokens=None + ): delta_msg_mock = Mock() delta_msg_mock.content = response["outputs"]["text"] if response["outputs"]["text"] == "a": diff --git a/tests/entrypoints/openai/tool_parsers/test_abstract_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_abstract_tool_parser.py new file mode 100644 index 00000000000..d8fd3acec0f --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_abstract_tool_parser.py @@ -0,0 +1,99 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest + +from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser + + +class _DummyTokenizer: + def get_vocab(self): + return {} + + +class _PairedTagParser(ToolParser): + """A concrete parser declaring paired sentinel tokens for testing.""" + + tool_call_start_token = "" + tool_call_end_token = "" + + +class _NoSentinelParser(ToolParser): + """A parser that did not opt in to prefix detection.""" + + +class TestDetectToolPrefix(unittest.TestCase): + def setUp(self): + self.tokenizer = _DummyTokenizer() + self.parser = _PairedTagParser(self.tokenizer) + + def test_initial_state(self): + self.assertEqual(self.parser._tool_prefix, "") + self.assertFalse(self.parser._tool_prefix_computed) + self.assertFalse(self.parser._tool_prefix_injected_to_delta) + + def test_empty_prompt_returns_empty(self): + self.assertEqual(self.parser.detect_tool_prefix(""), "") + + def test_no_start_token_returns_empty(self): + self.assertEqual( + self.parser.detect_tool_prefix("user: hello\nassistant: hi"), + "", + ) + + def test_parser_without_sentinel_returns_empty(self): + parser = _NoSentinelParser(self.tokenizer) + self.assertEqual( + parser.detect_tool_prefix("anything here"), + "", + ) + + def test_trailing_start_token_only(self): + prompt = "user: q\n" + self.assertEqual(self.parser.detect_tool_prefix(prompt), "") + + def test_trailing_start_with_invoke_prefix(self): + prompt = "history\n{...}\nuser: next" + self.assertEqual(self.parser.detect_tool_prefix(prompt), "") + + def test_history_closed_plus_new_injected_prefix(self): + prompt = "{a:1}\n{a:1}\n" "{b:2}\n" "assistant: done" + self.assertEqual(self.parser.detect_tool_prefix(prompt), "") + + def test_trailing_whitespace_after_start(self): + prompt = "history\n " + self.assertEqual( + self.parser.detect_tool_prefix(prompt), + " ", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py index ebb4c9ff127..940fe51ec46 100644 --- a/tests/input/test_text_processor.py +++ b/tests/input/test_text_processor.py @@ -130,6 +130,8 @@ def _create_dummy_modules(): info=lambda *args, **kwargs: None, warning=lambda *args, **kwargs: None, debug=lambda *args, **kwargs: None, + exception=lambda *args, **kwargs: None, + error=lambda *args, **kwargs: None, ) CHOICE_SEPARATOR = "::n::" @@ -330,6 +332,13 @@ def create_dummy_tool_parser(tokenizer, content="tool-text"): class DummyToolParser: def __init__(self, tokenizer): self.tokenizer = tokenizer + self._tool_prefix = "" + self._tool_prefix_token_ids = [] + self._tool_prefix_computed = False + self._tool_prefix_injected_to_delta = False + + def detect_tool_prefix(self, prompt): + return "" def extract_tool_calls(self, full_text, response_dict): # 模拟工具调用解析,返回固定的工具调用数据用于测试 @@ -570,7 +579,9 @@ def test_process_response_with_reasoning_and_tools(self): "outputs": {"token_ids": [1, processor.tokenizer.eos_token_id]}, } - processed = processor.process_response_dict(response, stream=False) + processed = processor.process_response_dict( + response, stream=False, request=SimpleNamespace(chat_template_kwargs=None) + ) self.assertEqual(processed["outputs"]["reasoning_content"], "think") self.assertEqual(processed["outputs"]["tool_calls"], ["tool"]) @@ -597,7 +608,9 @@ def test_process_response_streaming_with_reasoning_and_tools(self): "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]}, } - result = processor.process_response_dict_streaming(response, enable_thinking=True) + result = processor.process_response_dict_streaming( + response, enable_thinking=True, request=SimpleNamespace(chat_template_kwargs=None) + ) self.assertEqual(result["outputs"]["completion_tokens"], "7") self.assertEqual(result["outputs"]["text"], "tool-text") self.assertEqual(result["outputs"]["reasoning_content"], "because") @@ -615,7 +628,9 @@ def test_process_response_dict_normal_with_reasoning(self): "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]}, } - result = processor.process_response_dict_normal(response, enable_thinking=True) + result = processor.process_response_dict_normal( + response, enable_thinking=True, request=SimpleNamespace(chat_template_kwargs=None) + ) self.assertEqual(result["outputs"]["completion_tokens"], "7") self.assertEqual(result["outputs"]["reasoning_content"], "because") self.assertEqual(result["outputs"]["reasoning_token_num"], 1) @@ -753,5 +768,213 @@ def custom_convert(tokens): self.assertEqual(processor.update_bad_words(["combo", "oversize"], []), []) +class _RecordingToolParser: + """Minimal tool parser that records inputs and exposes the prefix-state + fields the serving layer reads/writes.""" + + def __init__(self, tokenizer, tool_prefix="", detect_raises=False): + self.tokenizer = tokenizer + self._configured_prefix = tool_prefix + self._detect_raises = detect_raises + self._tool_prefix = "" + self._tool_prefix_computed = False + self._tool_prefix_injected_to_delta = False + self.detect_calls = [] + self.extract_calls = [] + self.streaming_calls = [] + + def detect_tool_prefix(self, prompt): + self.detect_calls.append(prompt) + if self._detect_raises: + raise RuntimeError("boom") + return self._configured_prefix if prompt and prompt.endswith(self._configured_prefix) else "" + + def extract_tool_calls(self, model_output, request): + self.extract_calls.append(model_output) + return SimpleNamespace(tools_called=True, tool_calls=["tc"]) + + def extract_tool_calls_streaming( + self, + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request, + ): + self.streaming_calls.append( + { + "previous_text": previous_text, + "current_text": current_text, + "delta_text": delta_text, + "previous_token_ids": list(previous_token_ids), + "current_token_ids": list(current_token_ids), + "delta_token_ids": list(delta_token_ids), + } + ) + tool_calls = [ + DeltaToolCall( + index=0, + type="function", + id="x", + function=DeltaFunctionCall(name="t").model_dump(exclude_none=True), + ) + ] + return DeltaMessage(tool_calls=tool_calls, content="c") + + +class ToolPrefixCompensationTest(unittest.TestCase): + """Tests for the forced-tool-call prefix compensation logic in + ``BaseTextProcessor``. Splicing is driven entirely by whether the + rendered prompt ends with an unclosed tool-call start token, not by + request parameter introspection.""" + + def setUp(self): + module, cleanup = _import_text_processor() + self.text_processor_module = module + self.addCleanup(cleanup) + self.processor = module.TextProcessor("stub-model") + + def _make_parser_factory(self, parser): + return lambda tokenizer: parser + + def test_prepare_tool_prefix_idempotent(self): + parser = _RecordingToolParser(self.processor.tokenizer) + prompt = "history\n" + + self.processor._prepare_tool_prefix(parser, prompt) + self.assertTrue(parser._tool_prefix_computed) + self.assertEqual(parser._tool_prefix, "") + self.assertEqual(len(parser.detect_calls), 1) + + # Second call must not invoke detect again. + self.processor._prepare_tool_prefix(parser, prompt) + self.assertEqual(len(parser.detect_calls), 1) + + def test_prepare_tool_prefix_no_prompt(self): + parser = _RecordingToolParser(self.processor.tokenizer) + self.processor._prepare_tool_prefix(parser, None) + self.assertTrue(parser._tool_prefix_computed) + self.assertEqual(parser._tool_prefix, "") + self.assertEqual(parser.detect_calls, []) + + parser2 = _RecordingToolParser(self.processor.tokenizer) + self.processor._prepare_tool_prefix(parser2, "") + self.assertEqual(parser2._tool_prefix, "") + self.assertEqual(parser2.detect_calls, []) + + def test_prepare_tool_prefix_handles_exception(self): + parser = _RecordingToolParser(self.processor.tokenizer, detect_raises=True) + self.processor._prepare_tool_prefix(parser, "history\n") + self.assertTrue(parser._tool_prefix_computed) + self.assertEqual(parser._tool_prefix, "") + + def test_normal_path_splices_prefix_when_prompt_has_prefix(self): + """Prompt ending with an unclosed tool-call start triggers splicing, + regardless of how the user requested it.""" + processor = self.processor + parser = _RecordingToolParser(processor.tokenizer, tool_prefix="") + processor.tool_parser_obj = self._make_parser_factory(parser) + + response = { + "request_id": "req-normal", + "finished": True, + "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]}, + } + processor.process_response_dict_normal( + response, + request=SimpleNamespace(chat_template_kwargs=None), + prompt_tokens="user msg\n", + ) + self.assertEqual(len(parser.extract_calls), 1) + # Model output is "7" after decoding token 7; prefix must be prepended. + self.assertTrue(parser.extract_calls[0].startswith("")) + self.assertEqual(response["outputs"]["tool_calls"], ["tc"]) + + def test_normal_path_no_splice_when_prompt_lacks_prefix(self): + """No prefix in prompt tail => detect returns "" => no splice.""" + processor = self.processor + parser = _RecordingToolParser(processor.tokenizer, tool_prefix="") + processor.tool_parser_obj = self._make_parser_factory(parser) + + response = { + "request_id": "req-auto", + "finished": True, + "outputs": {"token_ids": [7, processor.tokenizer.eos_token_id]}, + } + processor.process_response_dict_normal( + response, + request=SimpleNamespace(chat_template_kwargs=None), + prompt_tokens="user msg without sentinel", + ) + # detect_tool_prefix is called, but returns "" => no prefix prepended. + self.assertEqual(len(parser.detect_calls), 1) + self.assertFalse(parser.extract_calls[0].startswith("")) + + def test_streaming_path_splices_prefix_only_on_first_delta(self): + processor = self.processor + parser = _RecordingToolParser(processor.tokenizer, tool_prefix="") + processor.tool_parser_obj = self._make_parser_factory(parser) + request = SimpleNamespace(chat_template_kwargs=None) + prompt_tokens = "user msg\n" + + # First chunk + first = { + "finished": False, + "request_id": "stream-req", + "outputs": {"token_ids": [7]}, + } + processor.process_response_dict_streaming(first, request=request, prompt_tokens=prompt_tokens) + first_call = parser.streaming_calls[0] + # delta_text decodes to "7"; previous="" current="7" + self.assertEqual(first_call["previous_text"], "") + self.assertEqual(first_call["current_text"], "7") + self.assertEqual(first_call["delta_text"], "7") + # token_ids must be spliced too — DummyTokenizer.encode("") -> [11]. + prefix_ids = [11] + self.assertEqual(first_call["previous_token_ids"], prefix_ids) + self.assertEqual(first_call["current_token_ids"], prefix_ids + [7]) + self.assertEqual(first_call["delta_token_ids"], prefix_ids + [7]) + self.assertTrue(parser._tool_prefix_injected_to_delta) + self.assertEqual(parser._tool_prefix_token_ids, prefix_ids) + + # Second chunk: delta must NOT be re-spliced, but previous/current are. + second = { + "finished": True, + "request_id": "stream-req", + "outputs": {"token_ids": [8, processor.tokenizer.eos_token_id]}, + } + processor.process_response_dict_streaming(second, request=request, prompt_tokens=prompt_tokens) + second_call = parser.streaming_calls[1] + self.assertEqual(second_call["previous_text"], "7") + self.assertEqual(second_call["current_text"], "78") + self.assertEqual(second_call["delta_text"], "8") # no extra prefix splice + # ``is_end=True`` causes the eos token to be stripped before ids2tokens, + # so token_ids fed to the parser is just [8]. + self.assertEqual(second_call["previous_token_ids"], prefix_ids + [7]) + self.assertEqual(second_call["current_token_ids"], prefix_ids + [7, 8]) + self.assertEqual(second_call["delta_token_ids"], [8]) + # detect should only run once across the whole stream. + self.assertEqual(len(parser.detect_calls), 1) + + def test_streaming_path_no_splice_when_no_prefix_detected(self): + processor = self.processor + # Empty configured prefix => detect returns "" even when prompt looks + # like a forced rendering. + parser = _RecordingToolParser(processor.tokenizer, tool_prefix="") + processor.tool_parser_obj = self._make_parser_factory(parser) + request = SimpleNamespace(chat_template_kwargs=None) + + first = { + "finished": False, + "request_id": "stream-noprefix", + "outputs": {"token_ids": [7]}, + } + processor.process_response_dict_streaming(first, request=request, prompt_tokens="no sentinel") + self.assertEqual(parser.streaming_calls[0]["delta_text"], "7") + self.assertFalse(parser._tool_prefix_injected_to_delta) + + if __name__ == "__main__": unittest.main()