@@ -29,7 +29,7 @@
     ChatCompletionStreamChoice,
 )
 from endpoints.OAI.types.common import UsageStats
-from endpoints.OAI.types.tools import ToolCall
+from endpoints.OAI.types.tools import NamedToolChoice, ToolCall
 from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
 from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
@@ -54,6 +54,7 @@ def _create_response(
     generations: List[dict],
     model_name: Optional[str],
     tool_call_format: str = "json",
+    tool_choice=None,
 ):
     """Create a chat completion response from the provided text."""
 
@@ -66,6 +67,10 @@ def _create_response(
         tool_calls_raw = generation.get("tool_calls")
         if tool_calls_raw:
             parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format)
+            if parsed and isinstance(tool_choice, NamedToolChoice):
+                parsed = ToolCallProcessor.filter_by_name(
+                    parsed, tool_choice.function.name
+                )
             if parsed:
                 message.tool_calls = parsed
             else:
@@ -488,7 +493,7 @@ async def stream_generate_chat_completion(
                 raise CancelledError()
 
             # Handle options if a tool model is present
-            if tool_start:
+            if tool_start and data.tool_choice != "none":
                 if "stop_str" in generation:
                     generations = await generate_tool_calls(
                         prompt,
@@ -507,6 +512,10 @@ async def stream_generate_chat_completion(
                     parsed = ToolCallProcessor.parse(
                         tool_calls_raw, format=tool_call_format
                     )
+                    if parsed and isinstance(data.tool_choice, NamedToolChoice):
+                        parsed = ToolCallProcessor.filter_by_name(
+                            parsed, data.tool_choice.function.name
+                        )
                     if parsed:
                         for tc_chunk in _build_tool_call_chunks(
                             parsed,
@@ -616,7 +625,10 @@ async def generate_chat_completion(
         generations = await asyncio.gather(*gen_tasks)
 
         # Check all the generations and see if a tool call is required
-        if tool_start:
+        force_tool_pass = data.tool_choice == "required" or isinstance(
+            data.tool_choice, NamedToolChoice
+        )
+        if tool_start or force_tool_pass:
             generations = await generate_tool_calls(
                 prompt, embeddings, data, generations, request
             )
@@ -626,6 +638,7 @@ async def generate_chat_completion(
             generations,
             model_path.name,
             tool_call_format=tool_call_format,
+            tool_choice=data.tool_choice,
         )
 
         logger.info(f"Finished chat completion request {request.state.id}")
@@ -652,6 +665,10 @@ async def generate_tool_calls(
     gen_tasks: List[asyncio.Task] = []
     tool_start = model.container.prompt_template.metadata.tool_start
     tool_call_format = model.container.prompt_template.metadata.tool_call_format
+    tool_choice = data.tool_choice
+
+    if tool_choice == "none":
+        return generations
 
     # Tracks which generations asked for a tool call
     tool_idx: List[int] = []
@@ -684,29 +701,35 @@ async def generate_tool_calls(
         tool_data.json_schema = TOOL_CALL_SCHEMA
 
     for idx, gen in enumerate(generations):
-        if gen["stop_str"] != tool_start:
+        stop_str = gen.get("stop_str")
+        should_generate = stop_str == tool_start
+
+        # Force tool generation if tool_choice requires it
+        if not should_generate and (
+            tool_choice == "required" or isinstance(tool_choice, NamedToolChoice)
+        ):
+            should_generate = True
+
+        if not should_generate:
             continue
 
         logger.info(
             f"Detected tool call in chat completion request "
             f"{request.state.id} (format={tool_call_format})"
         )
 
-        # Append the existing generation text if present
+        # Build per-generation prompt (avoid mutating shared prompt)
+        tool_prompt = prompt
         precursor_text = gen.get("full_text")
         if precursor_text:
-            prompt = prompt + precursor_text
+            tool_prompt = tool_prompt + precursor_text
 
         # For XML/auto mode: append tool_start back to prompt.
         # The stop string was consumed by the first pass and not included
         # in full_text, but the model expects to continue after <tool_call>.
        # Include a trailing newline to match the canonical template format.
-        if tool_call_format in ("xml", "auto"):
-            prompt = prompt + tool_start + "\n"
-            logger.debug(
-                f"generate_tool_calls: Appended '{tool_start}\\n' "
-                f"to prompt for XML continuation"
-            )
+        if tool_call_format in ("xml", "auto") and tool_start:
+            tool_prompt = tool_prompt + tool_start + "\n"
 
         gen_request_id = gen.get("request_id")
         tool_request_id = f"{gen_request_id}-tool"
@@ -715,7 +738,7 @@ async def generate_tool_calls(
             asyncio.create_task(
                 model.container.generate(
                     tool_request_id,
-                    prompt,
+                    tool_prompt,
                     tool_data,
                     mm_embeddings=embeddings,
                 )
@@ -734,10 +757,6 @@ async def generate_tool_calls(
         if tool_call_format in ("xml", "auto"):
             # Prepend tool_start to reconstruct complete XML for parser
             raw_text = tool_start + "\n" + raw_text
-            logger.debug(
-                f"generate_tool_calls: Raw XML tool call output "
-                f"({len(raw_text)} chars): {raw_text[:500]}..."
-            )
 
         generations[gen_idx]["tool_calls"] = raw_text
 
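For orientation, below is a minimal, self-contained sketch of the tool_choice semantics these hunks wire in, assuming OpenAI-style values ("none", "auto", "required", or a named function choice). The stub types (FunctionRef, ToolCallSketch, NamedToolChoiceSketch) are hypothetical stand-ins for the real endpoints.OAI.types.tools models, and filter_by_name shows the presumed behavior of ToolCallProcessor.filter_by_name as used in the diff, not its actual implementation.

# Stand-alone sketch, not the tabbyAPI implementation. All names below are
# hypothetical stand-ins for the real endpoints.OAI.types.tools models.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class FunctionRef:
    name: str


@dataclass
class ToolCallSketch:
    function: FunctionRef


@dataclass
class NamedToolChoiceSketch:
    function: FunctionRef


# tool_choice is "none", "auto", "required", or a named function choice
ToolChoice = Union[str, NamedToolChoiceSketch]


def force_tool_pass(tool_choice: ToolChoice) -> bool:
    """Mirror of the diff's force_tool_pass expression: "required" or a
    named function forces a tool-call pass even without a tool_start stop."""
    return tool_choice == "required" or isinstance(tool_choice, NamedToolChoiceSketch)


def filter_by_name(parsed: List[ToolCallSketch], tool_choice: ToolChoice) -> List[ToolCallSketch]:
    """Presumed behavior of ToolCallProcessor.filter_by_name: keep only
    the calls to the function named by the tool_choice."""
    if parsed and isinstance(tool_choice, NamedToolChoiceSketch):
        return [tc for tc in parsed if tc.function.name == tool_choice.function.name]
    return parsed


# Example: a named choice forces the pass and filters the parsed calls.
choice = NamedToolChoiceSketch(function=FunctionRef(name="get_weather"))
calls = [
    ToolCallSketch(function=FunctionRef(name="get_weather")),
    ToolCallSketch(function=FunctionRef(name="get_time")),
]
assert force_tool_pass(choice)
assert [tc.function.name for tc in filter_by_name(calls, choice)] == ["get_weather"]

Under these rules, "none" short-circuits the tool pass entirely, "required" or a named choice forces it, and a named choice additionally narrows the parsed output to the requested function, which matches the three call sites patched above.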