-
Notifications
You must be signed in to change notification settings - Fork 743
[Feature] Add OpenAI-compatible tool_choice support for chat completions #7882
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
6b2f806
47a7e23
7041f47
3017744
2ac3eb9
01eb97c
7c5af98
2d06f3e
2b12c9a
0cafd3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,15 @@ class ToolParser: | |
| derived classes. | ||
| """ | ||
|
|
||
| # Subclasses should override these with the literal tool-call sentinel | ||
| # tokens they recognize (e.g. ``"<tool_call>"`` / ``"</tool_call>"``). | ||
| # Used by :meth:`detect_tool_prefix` to support forced tool-call prompt | ||
| # prefix injection (named-tool ``tool_choice`` or | ||
| # ``chat_template_kwargs.options.tool_choice.mode == "force"``). Empty | ||
| # defaults make the detection a no-op for parsers that have not opted in. | ||
| tool_call_start_token: str = "" | ||
| tool_call_end_token: str = "" | ||
|
|
||
| def __init__(self, tokenizer): | ||
| self.prev_tool_call_arr: list[dict] = [] | ||
| # the index of the tool call that is currently being parsed | ||
|
|
@@ -43,6 +52,16 @@ def __init__(self, tokenizer): | |
|
|
||
| self.model_tokenizer = tokenizer | ||
|
|
||
| # Per-request tool-prefix state populated by the serving layer when | ||
| # the chat template injects a forced tool-call prefix into the prompt. | ||
| self._tool_prefix: str = "" | ||
| self._tool_prefix_token_ids: list[int] = [] | ||
| # Set after the prefix is computed once for this request. | ||
| self._tool_prefix_computed: bool = False | ||
| # Set after the prefix has been spliced into the streaming delta | ||
| # (only the first chunk needs it). | ||
| self._tool_prefix_injected_to_delta: bool = False | ||
|
|
||
| @cached_property | ||
| def vocab(self) -> dict[str, int]: | ||
| # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab | ||
|
|
@@ -55,6 +74,36 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques | |
| """ | ||
| return request | ||
|
|
||
| def detect_tool_prefix(self, prompt: str) -> str: | ||
| """Detect the tool-call prefix injected at the tail of the rendered | ||
| prompt by a forced ``tool_choice``. | ||
|
|
||
| Finds the **last** :attr:`tool_call_start_token` in ``prompt`` that is | ||
| not closed by a later :attr:`tool_call_end_token` and reaches the | ||
| prompt end (modulo trailing whitespace). Returns ``""`` otherwise. | ||
| Subclasses with non-paired tag formats may override. | ||
|
Comment on lines
+81
to
+84
|
||
| """ | ||
| start = self.tool_call_start_token | ||
| if not start or not prompt: | ||
| return "" | ||
|
|
||
| last_start = prompt.rfind(start) | ||
| if last_start == -1: | ||
| return "" | ||
|
|
||
| end = self.tool_call_end_token | ||
| if end and prompt.find(end, last_start + len(start)) != -1: | ||
| # The last start token is closed — this is a historical, completed | ||
| # tool-call (e.g. from a previous assistant turn), not an injected | ||
| # forced prefix. | ||
| return "" | ||
|
|
||
| # By construction, ``prompt[last_start:]`` reaches the end of the | ||
| # prompt. We treat the whole tail as the injected prefix. Subclasses | ||
| # whose chat templates place additional content after the prefix can | ||
| # override this method to apply stricter validation. | ||
| return prompt[last_start:] | ||
|
|
||
| def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议
建议修复方式:在 _MAX_PREFIX_LEN = 512
tail = prompt[last_start:]
if len(tail) > _MAX_PREFIX_LEN:
return ""
return tail |
||
| """ | ||
| Static method that should be implemented for extracting tool calls from | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.