Skip to content

Commit a2c7d81

Browse files
committed
Broader model compatibility, tool_choice support, bug fixes and cleanup
1 parent 87bbe0f commit a2c7d81

6 files changed

Lines changed: 254 additions & 82 deletions

File tree

backends/exllamav3/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,8 +1014,21 @@ async def generate_gen(
10141014
if chunk:
10151015
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
10161016
full_response += chunk
1017+
1018+
# Extract token IDs as a plain list for downstream consumers
10171019
if isinstance(chunk_tokens, torch.Tensor):
1020+
token_id_list = chunk_tokens.flatten().tolist()
10181021
generated_tokens += chunk_tokens.size(dim=0)
1022+
elif isinstance(chunk_tokens, tuple):
1023+
first = chunk_tokens[0]
1024+
if isinstance(first, torch.Tensor):
1025+
token_id_list = first.flatten().tolist()
1026+
else:
1027+
token_id_list = list(first)
1028+
generated_tokens += len(token_id_list)
1029+
else:
1030+
token_id_list = list(chunk_tokens)
1031+
generated_tokens += len(token_id_list)
10191032

10201033
# Increase penalty range to generated token amount
10211034
# TODO:
@@ -1025,6 +1038,7 @@ async def generate_gen(
10251038
generation = {
10261039
"request_id": request_id,
10271040
"text": chunk,
1041+
"token_ids": token_id_list,
10281042
"prompt_tokens": context_len,
10291043
"generated_tokens": generated_tokens,
10301044
"offset": len(full_response),

common/templating.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from jinja2.ext import loopcontrols
1313
from jinja2.sandbox import ImmutableSandboxedEnvironment
1414
from loguru import logger
15+
from markupsafe import Markup
1516
from packaging import version
1617

1718

@@ -33,6 +34,7 @@ class TemplateMetadata:
3334

3435
stop_strings: List[str] = field(default_factory=list)
3536
tool_start: Optional[str] = None
37+
tool_end: Optional[str] = None
3638
tool_call_format: str = "json"
3739

3840

@@ -50,6 +52,22 @@ class PromptTemplate:
5052
)
5153
metadata: Optional[TemplateMetadata] = None
5254

55+
@staticmethod
56+
def _tojson_compat(value, indent=None, ensure_ascii=True):
57+
"""Compatibility JSON filter for chat templates.
58+
59+
Some model templates call ``tojson(ensure_ascii=False)`` while the
60+
bundled Jinja filter may not accept that keyword in sandboxed mode.
61+
"""
62+
return Markup(
63+
json.dumps(
64+
value,
65+
indent=indent,
66+
ensure_ascii=ensure_ascii,
67+
separators=(",", ": "),
68+
)
69+
)
70+
5371
async def extract_metadata(self, template_vars: dict):
5472
"""
5573
Returns deserialized template metadata from a chat template.
@@ -80,6 +98,10 @@ async def extract_metadata(self, template_vars: dict):
8098
if isinstance(template_module.tool_start, str):
8199
template_metadata.tool_start = template_module.tool_start
82100

101+
if hasattr(template_module, "tool_end"):
102+
if isinstance(template_module.tool_end, str):
103+
template_metadata.tool_end = template_module.tool_end
104+
83105
if hasattr(template_module, "tool_call_format"):
84106
fmt = template_module.tool_call_format
85107
if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS:
@@ -123,6 +145,7 @@ def raise_exception(message):
123145

124146
self.environment.globals["strftime_now"] = strftime_now
125147
self.environment.globals["raise_exception"] = raise_exception
148+
self.environment.filters["tojson"] = self._tojson_compat
126149

127150
return self.environment.from_string(template_str)
128151

endpoints/OAI/types/chat_completion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from uuid import uuid4
55

66
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
7-
from endpoints.OAI.types.tools import ToolSpec, ToolCall
7+
from endpoints.OAI.types.tools import NamedToolChoice, ToolSpec, ToolCall
88

99

1010
class ChatCompletionLogprob(BaseModel):
@@ -71,6 +71,10 @@ class ChatCompletionRequest(CommonCompletionRequest):
7171

7272
tools: Optional[List[ToolSpec]] = None
7373
functions: Optional[List[Dict]] = None
74+
tool_choice: Optional[
75+
Union[Literal["none", "auto", "required"], NamedToolChoice]
76+
] = None
77+
parallel_tool_calls: Optional[bool] = True
7478

7579
# Chat completions requests do not have a BOS token preference. Backend
7680
# respects the tokenization config for the individual model.

endpoints/OAI/types/tools.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,16 @@ class ToolCall(BaseModel):
4040
function: Tool
4141
type: Literal["function"] = "function"
4242
index: Optional[int] = None
43+
44+
45+
class NamedToolFunction(BaseModel):
46+
"""Represents a named function reference for tool_choice."""
47+
48+
name: str
49+
50+
51+
class NamedToolChoice(BaseModel):
52+
"""Represents a named tool choice (forces a specific function call)."""
53+
54+
function: NamedToolFunction
55+
type: Literal["function"] = "function"

endpoints/OAI/utils/chat_completion.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
ChatCompletionStreamChoice,
3030
)
3131
from endpoints.OAI.types.common import UsageStats
32-
from endpoints.OAI.types.tools import ToolCall
32+
from endpoints.OAI.types.tools import NamedToolChoice, ToolCall
3333
from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
3434
from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
3535

@@ -54,6 +54,7 @@ def _create_response(
5454
generations: List[dict],
5555
model_name: Optional[str],
5656
tool_call_format: str = "json",
57+
tool_choice=None,
5758
):
5859
"""Create a chat completion response from the provided text."""
5960

@@ -66,6 +67,10 @@ def _create_response(
6667
tool_calls_raw = generation.get("tool_calls")
6768
if tool_calls_raw:
6869
parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format)
70+
if parsed and isinstance(tool_choice, NamedToolChoice):
71+
parsed = ToolCallProcessor.filter_by_name(
72+
parsed, tool_choice.function.name
73+
)
6974
if parsed:
7075
message.tool_calls = parsed
7176
else:
@@ -488,7 +493,7 @@ async def stream_generate_chat_completion(
488493
raise CancelledError()
489494

490495
# Handle options if a tool model is present
491-
if tool_start:
496+
if tool_start and data.tool_choice != "none":
492497
if "stop_str" in generation:
493498
generations = await generate_tool_calls(
494499
prompt,
@@ -507,6 +512,10 @@ async def stream_generate_chat_completion(
507512
parsed = ToolCallProcessor.parse(
508513
tool_calls_raw, format=tool_call_format
509514
)
515+
if parsed and isinstance(data.tool_choice, NamedToolChoice):
516+
parsed = ToolCallProcessor.filter_by_name(
517+
parsed, data.tool_choice.function.name
518+
)
510519
if parsed:
511520
for tc_chunk in _build_tool_call_chunks(
512521
parsed,
@@ -616,7 +625,10 @@ async def generate_chat_completion(
616625
generations = await asyncio.gather(*gen_tasks)
617626

618627
# Check all the generations and see if a tool call is required
619-
if tool_start:
628+
force_tool_pass = data.tool_choice == "required" or isinstance(
629+
data.tool_choice, NamedToolChoice
630+
)
631+
if tool_start or force_tool_pass:
620632
generations = await generate_tool_calls(
621633
prompt, embeddings, data, generations, request
622634
)
@@ -626,6 +638,7 @@ async def generate_chat_completion(
626638
generations,
627639
model_path.name,
628640
tool_call_format=tool_call_format,
641+
tool_choice=data.tool_choice,
629642
)
630643

631644
logger.info(f"Finished chat completion request {request.state.id}")
@@ -652,6 +665,10 @@ async def generate_tool_calls(
652665
gen_tasks: List[asyncio.Task] = []
653666
tool_start = model.container.prompt_template.metadata.tool_start
654667
tool_call_format = model.container.prompt_template.metadata.tool_call_format
668+
tool_choice = data.tool_choice
669+
670+
if tool_choice == "none":
671+
return generations
655672

656673
# Tracks which generations asked for a tool call
657674
tool_idx: List[int] = []
@@ -684,29 +701,35 @@ async def generate_tool_calls(
684701
tool_data.json_schema = TOOL_CALL_SCHEMA
685702

686703
for idx, gen in enumerate(generations):
687-
if gen["stop_str"] != tool_start:
704+
stop_str = gen.get("stop_str")
705+
should_generate = stop_str == tool_start
706+
707+
# Force tool generation if tool_choice requires it
708+
if not should_generate and (
709+
tool_choice == "required" or isinstance(tool_choice, NamedToolChoice)
710+
):
711+
should_generate = True
712+
713+
if not should_generate:
688714
continue
689715

690716
logger.info(
691717
f"Detected tool call in chat completion request "
692718
f"{request.state.id} (format={tool_call_format})"
693719
)
694720

695-
# Append the existing generation text if present
721+
# Build per-generation prompt (avoid mutating shared prompt)
722+
tool_prompt = prompt
696723
precursor_text = gen.get("full_text")
697724
if precursor_text:
698-
prompt = prompt + precursor_text
725+
tool_prompt = tool_prompt + precursor_text
699726

700727
# For XML/auto mode: append tool_start back to prompt.
701728
# The stop string was consumed by the first pass and not included
702729
# in full_text, but the model expects to continue after <tool_call>.
703730
# Include a trailing newline to match the canonical template format.
704-
if tool_call_format in ("xml", "auto"):
705-
prompt = prompt + tool_start + "\n"
706-
logger.debug(
707-
f"generate_tool_calls: Appended '{tool_start}\\n' "
708-
f"to prompt for XML continuation"
709-
)
731+
if tool_call_format in ("xml", "auto") and tool_start:
732+
tool_prompt = tool_prompt + tool_start + "\n"
710733

711734
gen_request_id = gen.get("request_id")
712735
tool_request_id = f"{gen_request_id}-tool"
@@ -715,7 +738,7 @@ async def generate_tool_calls(
715738
asyncio.create_task(
716739
model.container.generate(
717740
tool_request_id,
718-
prompt,
741+
tool_prompt,
719742
tool_data,
720743
mm_embeddings=embeddings,
721744
)
@@ -734,10 +757,6 @@ async def generate_tool_calls(
734757
if tool_call_format in ("xml", "auto"):
735758
# Prepend tool_start to reconstruct complete XML for parser
736759
raw_text = tool_start + "\n" + raw_text
737-
logger.debug(
738-
f"generate_tool_calls: Raw XML tool call output "
739-
f"({len(raw_text)} chars): {raw_text[:500]}..."
740-
)
741760

742761
generations[gen_idx]["tool_calls"] = raw_text
743762

0 commit comments

Comments (0)