Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,11 @@ def _send_error_response(self, request_id, error_msg, error_code: int = 500, wor
def _decode_token(self, token_ids, req_id, is_end):
delta_text = ""
if envs.FD_ENABLE_RETURN_TEXT:
delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
delta_text, previous_token_ids, _ = self.data_processor.ids2tokens(token_ids, req_id)
# Reconstruct the post-extend cumulative list from the pre-delta
# snapshot + this call's input — ``ids2tokens`` only returns the
# snapshot to keep its return values aliasing-free.
cum_tokens = previous_token_ids + list(token_ids)

This comment was marked as outdated.

if delta_text != "":
prefix_offset = self.data_processor.decode_status[req_id][0]
read_offset = self.data_processor.decode_status[req_id][1]
Expand Down
12 changes: 11 additions & 1 deletion fastdeploy/entrypoints/openai/response_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def accumulate_token_ids(self, request_output):
else:

This comment was marked as outdated.

self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})

async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output, request):
async def process_response_chat(
self, request_outputs, stream, include_stop_str_in_output, request, prompt_tokens=None
):
"""
Process a list of responses into a generator that yields each processed response as it's generated.
Args:
Expand Down Expand Up @@ -101,6 +103,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
audio_tokens=all_audio_tokens,
tts=tts,
request=request,
prompt_tokens=prompt_tokens,

This comment was marked as outdated.

)
else:
response = self.data_processor.process_response_dict(
Expand All @@ -110,6 +113,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
audio_tokens=all_audio_tokens,
tts=tts,
request=request,
prompt_tokens=prompt_tokens,
)
yield response
elif decode_type == 2: # audio
Expand All @@ -128,13 +132,15 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,
)
else:
response = self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,
)
yield response
elif stream:
Expand Down Expand Up @@ -168,13 +174,15 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,
)
else:
self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,
)
text = {"type": "text", "text": request_output["outputs"]["text"]}
request_output["outputs"]["multipart"] = [text]
Expand All @@ -197,13 +205,15 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
stream=False,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,
)
else:
self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
Comment on lines 211 to 215
prompt_tokens=prompt_tokens,
Comment on lines 212 to +216
)
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
multipart.append(text)
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ async def chat_completion_stream_generator(
stream=True,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,

This comment was marked as outdated.

)

async for res in generator:
Expand Down Expand Up @@ -650,6 +651,7 @@ async def chat_completion_full_generator(
stream=False,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
prompt_tokens=prompt_tokens,
)
async for data in generator:
idx = get_choice_index(data["request_id"])
Expand Down
49 changes: 49 additions & 0 deletions fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ class ToolParser:
derived classes.
"""

# Subclasses should override these with the literal tool-call sentinel
# tokens they recognize (e.g. ``"<tool_call>"`` / ``"</tool_call>"``).
# Used by :meth:`detect_tool_prefix` to support forced tool-call prompt
# prefix injection (named-tool ``tool_choice`` or
# ``chat_template_kwargs.options.tool_choice.mode == "force"``). Empty
# defaults make the detection a no-op for parsers that have not opted in.
tool_call_start_token: str = ""
tool_call_end_token: str = ""

def __init__(self, tokenizer):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
Expand All @@ -43,6 +52,16 @@ def __init__(self, tokenizer):

self.model_tokenizer = tokenizer

# Per-request tool-prefix state populated by the serving layer when
# the chat template injects a forced tool-call prefix into the prompt.
self._tool_prefix: str = ""
self._tool_prefix_token_ids: list[int] = []
# Set after the prefix is computed once for this request.
self._tool_prefix_computed: bool = False
# Set after the prefix has been spliced into the streaming delta
# (only the first chunk needs it).
self._tool_prefix_injected_to_delta: bool = False

@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
Expand All @@ -55,6 +74,36 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques
"""
return request

def detect_tool_prefix(self, prompt: str) -> str:
"""Detect the tool-call prefix injected at the tail of the rendered
prompt by a forced ``tool_choice``.

Finds the **last** :attr:`tool_call_start_token` in ``prompt`` that is
not closed by a later :attr:`tool_call_end_token` and reaches the
prompt end (modulo trailing whitespace). Returns ``""`` otherwise.
Subclasses with non-paired tag formats may override.
Comment on lines +81 to +84
"""
start = self.tool_call_start_token
if not start or not prompt:
return ""

last_start = prompt.rfind(start)
if last_start == -1:
return ""

end = self.tool_call_end_token
if end and prompt.find(end, last_start + len(start)) != -1:
# The last start token is closed — this is a historical, completed
# tool-call (e.g. from a previous assistant turn), not an injected
# forced prefix.
return ""

# By construction, ``prompt[last_start:]`` reaches the end of the
# prompt. We treat the whole tail as the injected prefix. Subclasses
# whose chat templates place additional content after the prefix can
# override this method to apply stricter validation.
return prompt[last_start:]

def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 detect_tool_prefix 对未闭合 start token 的判断过于宽松

prompt.rfind(start) 找到最后一个 start token 后,只检查其后是否有 end token,但未验证该 start token 是否真正位于 prompt 末尾附近。若历史对话中存在格式异常的未闭合 start token,而 prompt 末尾还有其他内容,则从该位置到 prompt 末尾的所有内容(可能非常长)会被当作 prefix 注入,导致 tool parser 接收到错误的解析输入。

建议修复方式:在 return prompt[last_start:] 前添加长度保护,超过合理阈值时视为误检返回 ""

_MAX_PREFIX_LEN = 512
tail = prompt[last_start:]
if len(tail) > _MAX_PREFIX_LEN:
    return ""
return tail

"""
Static method that should be implemented for extracting tool calls from

This comment was marked as outdated.

Expand Down
132 changes: 107 additions & 25 deletions fastdeploy/input/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,28 @@ def text2ids(self, text, max_model_len=None, **kwargs):
)
return tokens["input_ids"][0]

def _text_to_token_ids(self, text: str) -> list:
"""Encode ``text`` to a ``list[int]``, shared by :meth:`messages2ids`
and :meth:`_prepare_tool_prefix`.

``ernie4_5`` tokenizer hangs on long inputs via ``.encode()``, so it
goes through ``tokenize`` + ``convert_tokens_to_ids``. Other tokenizers
use ``.encode()`` and the result is normalized to a plain list.
"""
if self.tokenizer_type == "ernie4_5":
# NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode()
return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
token_ids = self.tokenizer.encode(text, add_special_tokens=False)
if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
token_ids = token_ids["input_ids"]
if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
token_ids = token_ids[0]
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
if not isinstance(token_ids, list):
token_ids = list(token_ids)
return token_ids

def messages2ids(self, request, **kwargs):
"""Convert a chat-template request into a token-ID list.

Expand All @@ -159,19 +181,7 @@ def messages2ids(self, request, **kwargs):
)
request["prompt_tokens"] = spliced_message
req_id = request.get("request_id", None) if isinstance(request, dict) else None
if self.tokenizer_type == "ernie4_5":
# NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode()
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(spliced_message))
else:
token_ids = self.tokenizer.encode(spliced_message, add_special_tokens=False)
if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
token_ids = token_ids["input_ids"]
if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
token_ids = token_ids[0]
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
if not isinstance(token_ids, list):
token_ids = list(token_ids)
token_ids = self._text_to_token_ids(spliced_message)
log_request(
level=1,
message="req_id:{req_id}, token_ids: {token_ids}",
Expand Down Expand Up @@ -204,9 +214,16 @@ def ids2tokens(self, token_id, task_id):
Returns:
(delta_text, previous_token_ids, previous_texts)

Both the HF and the PaddleFormers/ERNIE tokeniser paths return the
same tuple shape. The HF path sets ``previous_token_ids`` to ``[]``
since it does not expose per-token ids during batch-decode.
``previous_token_ids`` and ``previous_texts`` are **snapshots of the
accumulated state BEFORE this call's tokens were appended** —
symmetric pre-delta views of what the caller had decoded so far.
Both are owned by the caller (no aliasing of internal state).
Comment on lines +217 to +220

Callers that need the post-extend cumulative list should reconstruct
it locally via ``previous_token_ids + token_id``.

The HF path returns ``[]`` for ``previous_token_ids`` since it does
not expose per-token ids during batch-decode.
"""
if envs.FD_USE_HF_TOKENIZER:
if task_id not in self.decode_status:
Expand All @@ -225,20 +242,25 @@ def ids2tokens(self, token_id, task_id):
status[2] = decode_str[0]
else:
new_str = ""
# Return consistent three-tuple; previous_token_ids not available.
# NOTE: HF path historically returns the post-delta full string
# here, inconsistent with the non-HF branch (which returns the
# pre-delta snapshot). Preserved as-is to avoid behavior change.
return new_str, [], status[2]
else:
if task_id not in self.decode_status:
# [prefix_offset, read_offset, all_token_ids, accumulated_text]
self.decode_status[task_id] = [0, 0, [], ""]
status = self.decode_status[task_id]
previous_texts = status[3]
# Snapshot BEFORE extend so the returned list is owned by the
# caller and symmetric with ``previous_texts``.
previous_token_ids = list(status[2])
status[2].extend(token_id)
decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(status[2], status[0], status[1])
status[0] = prefix_offset
status[1] = read_offset
Comment on lines +257 to 261
status[3] += decode_str
return decode_str, status[2], previous_texts
return decode_str, previous_token_ids, previous_texts

This comment was marked as outdated.

This comment was marked as outdated.

# ------------------------------------------------------------------
# Response processing
Expand Down Expand Up @@ -266,6 +288,37 @@ def process_response_dict(self, response_dict, **kwargs):
else:
return self.process_response_dict_normal(response_dict, **kwargs)

def _prepare_tool_prefix(self, tool_parser, prompt_tokens):
"""Detect and cache on ``tool_parser`` the tool-call prefix that the
chat template injected at the tail of ``prompt_tokens`` (the rendered
prompt string from the serving layer). Computed once per parser
instance via the parser's :meth:`ToolParser.detect_tool_prefix`.
"""
if tool_parser._tool_prefix_computed:
return
tool_parser._tool_prefix_computed = True

This comment was marked as outdated.

tool_parser._tool_prefix = ""
tool_parser._tool_prefix_token_ids = []
if not prompt_tokens or not isinstance(prompt_tokens, str):
return
try:
prefix = tool_parser.detect_tool_prefix(prompt_tokens) or ""
except Exception:
data_processor_logger.exception("detect_tool_prefix failed; falling back to empty prefix")
return
tool_parser._tool_prefix = prefix
if not prefix:
return
# Encode the prefix into token ids so the streaming path can also
# splice ``previous/current/delta_token_ids`` — some parsers gate on
# ``tool_call_start_token_id in current_token_ids`` rather than on
# text (e.g. ``Ernie45VLThinkingToolParser``).
try:
tool_parser._tool_prefix_token_ids = self._text_to_token_ids(prefix)
except Exception:
data_processor_logger.exception("encode tool prefix to token ids failed; token-id splice disabled")
tool_parser._tool_prefix_token_ids = []
Comment on lines +316 to +320

def process_response_dict_normal(self, response_dict, **kwargs):
"""Accumulate tokens and build the full completion text (non-streaming)."""
token_ids = response_dict["outputs"]["token_ids"]
Expand Down Expand Up @@ -300,7 +353,11 @@ def process_response_dict_normal(self, response_dict, **kwargs):

if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
parser_input = full_text
self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens"))
if tool_parser._tool_prefix:
parser_input = tool_parser._tool_prefix + full_text
tool_call_info = tool_parser.extract_tool_calls(parser_input, request)
if tool_call_info.tools_called:
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls

Expand Down Expand Up @@ -354,13 +411,38 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
if req_id not in self.tool_parser_dict:
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parser_dict[req_id]
stream_previous = previous_texts
stream_current = previous_texts + delta_text
stream_delta = delta_text
stream_previous_token_ids = previous_token_ids
stream_current_token_ids = previous_token_ids + token_ids
stream_delta_token_ids = token_ids
self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens"))
prefix = tool_parser._tool_prefix
prefix_ids = tool_parser._tool_prefix_token_ids
# Splice the injected prefix back into both text and token-id
# streaming args so parsers that gate on either form (e.g.
# ``Ernie45VLThinkingToolParser`` checks
# ``tool_call_start_token_id in current_token_ids``) work
# unchanged. ``delta_*`` only spliced on the first call.
if prefix:
stream_previous = prefix + stream_previous
stream_current = prefix + stream_current
if prefix_ids:
stream_previous_token_ids = list(prefix_ids) + list(stream_previous_token_ids)
stream_current_token_ids = list(prefix_ids) + list(stream_current_token_ids)
if not tool_parser._tool_prefix_injected_to_delta:
stream_delta = prefix + stream_delta
if prefix_ids:
stream_delta_token_ids = list(prefix_ids) + list(stream_delta_token_ids)
tool_parser._tool_prefix_injected_to_delta = True
tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
stream_previous,
stream_current,
stream_delta,
stream_previous_token_ids,
stream_current_token_ids,
stream_delta_token_ids,
request,
)
if tool_call_delta_message:
Expand Down
Loading
Loading