diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index 3a6c36624d..f16655a8c8 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -1308,6 +1308,159 @@ def _build_tool_call_from_json_dict(
   return tool_call
 
 
+# DeepSeek models may emit tool calls as inline text using proprietary
+# special tokens. See https://api-docs.deepseek.com/guides/function_calling
+# for the full specification. LiteLLM usually translates these into
+# structured `tool_calls` but when it doesn't (intermittent), the raw
+# tokens land in the `content` field and must be parsed here.
+_DS_TCALLS_BEGIN = "\u003c\uff5ctool\u2581calls\u2581begin\uff5c\u003e"
+_DS_TCALLS_END = "\u003c\uff5ctool\u2581calls\u2581end\uff5c\u003e"
+_DS_TCALL_BEGIN = "\u003c\uff5ctool\u2581call\u2581begin\uff5c\u003e"
+_DS_TCALL_END = "\u003c\uff5ctool\u2581call\u2581end\uff5c\u003e"
+_DS_TSEP = "\u003c\uff5ctool\u2581sep\uff5c\u003e"
+
+# Pattern: <|tool▁call▁begin|>function<|tool▁sep|>NAME \n ARGS <|tool▁call▁end|>
+_DS_TOOL_CALL_RE = re.compile(
+    re.escape(_DS_TCALL_BEGIN)
+    + r"function"
+    + re.escape(_DS_TSEP)
+    + r"([^\n\r]+?)\s*?\n(.*?)"
+    + re.escape(_DS_TCALL_END),
+    re.DOTALL,
+)
+
+
+def _extract_json_from_deepseek_args(args_text: str) -> Optional[str]:
+  """Extracts a JSON object string from DeepSeek arguments text.
+
+  Args:
+    args_text: Raw text containing the function arguments, possibly
+      wrapped in Markdown-style code fences.
+
+  Returns:
+    The JSON object as a string, or None if no valid JSON object could be
+    found. Fenced content is returned verbatim once validated; content
+    recovered by the fallback scan is re-serialized with ``json.dumps``.
+  """
+  if not args_text:
+    return None
+  # Strip optional Markdown code fences (```json ... ``` or ``` ... ```).
+  fence_match = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", args_text)
+  if fence_match:
+    fenced = fence_match.group(1).strip()
+    try:
+      json.loads(fenced)
+    except json.JSONDecodeError:
+      # The lazy fence pattern can capture a span that is not valid JSON,
+      # e.g. when a string value itself contains ``` fences. Fall through
+      # to the decoder-based scan instead of returning invalid arguments.
+      pass
+    else:
+      return fenced
+  # Fall back to the first balanced { … } block.
+  open_brace = args_text.find("{")
+  if open_brace == -1:
+    return None
+  try:
+    candidate, _ = _JSON_DECODER.raw_decode(args_text, open_brace)
+  except json.JSONDecodeError:
+    return None
+  return json.dumps(candidate, ensure_ascii=False)
+
+
+def _parse_deepseek_tool_calls_from_text(
+    text_block: str,
+) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
+  """Parses DeepSeek proprietary inline tool-call tokens from text.
+
+  When LiteLLM does not translate DeepSeek's special tokens into
+  structured ``tool_calls``, the raw tokens appear inside the ``content``
+  field. This function extracts them and returns standard
+  ``ChatCompletionMessageToolCall`` objects.
+
+  Token reference
+    ``<|tool▁calls▁begin|>`` … ``<|tool▁calls▁end|>`` → outer wrapper
+    ``<|tool▁call▁begin|>function<|tool▁sep|>NAME`` → single call start
+    ``<|tool▁call▁end|>`` → single call end
+
+  Args:
+    text_block: The raw text that may contain DeepSeek tokens.
+
+  Returns:
+    A tuple of ``(tool_calls, remainder)``. ``remainder`` is the original
+    text with every terminated DeepSeek token region removed, or None when
+    nothing is left. If ``text_block`` contains no DeepSeek begin token the
+    result is ``([], None)`` and the caller should keep the original text.
+    A region whose end token is missing is left in ``remainder`` as-is.
+  """
+  _ensure_litellm_imported()
+
+  tool_calls: list[ChatCompletionMessageToolCall] = []
+  if not text_block:
+    return tool_calls, None
+
+  # Quick guard: only scan when at least one begin token is present.
+  if _DS_TCALLS_BEGIN not in text_block and _DS_TCALL_BEGIN not in text_block:
+    return tool_calls, None
+
+  remainder_parts: list[str] = []
+  cursor = 0
+
+  # Outer loop — wrapped blocks and unwrapped calls may be interleaved.
+  while True:
+    wrapped_idx = text_block.find(_DS_TCALLS_BEGIN, cursor)
+    bare_idx = text_block.find(_DS_TCALL_BEGIN, cursor)
+    found = [idx for idx in (wrapped_idx, bare_idx) if idx != -1]
+    if not found:
+      remainder_parts.append(text_block[cursor:])
+      break
+
+    # Handle whichever token comes first so that an unwrapped call ahead
+    # of a wrapped block is parsed in order instead of leaking, tokens
+    # and all, into the remainder.
+    begin_idx = min(found)
+    in_wrapped_block = begin_idx == wrapped_idx
+
+    # Everything before the token becomes remainder.
+    if begin_idx > cursor:
+      remainder_parts.append(text_block[cursor:begin_idx])
+
+    if in_wrapped_block:
+      end_idx = text_block.find(
+          _DS_TCALLS_END, begin_idx + len(_DS_TCALLS_BEGIN)
+      )
+      if end_idx == -1:
+        remainder_parts.append(text_block[begin_idx:])
+        break
+      block = text_block[begin_idx + len(_DS_TCALLS_BEGIN) : end_idx]
+      cursor = end_idx + len(_DS_TCALLS_END)
+    else:
+      # Unwrapped call token — scan for a matching end token.
+      end_idx = text_block.find(_DS_TCALL_END, begin_idx + len(_DS_TCALL_BEGIN))
+      if end_idx == -1:
+        remainder_parts.append(text_block[begin_idx:])
+        break
+      block = text_block[begin_idx : end_idx + len(_DS_TCALL_END)]
+      cursor = end_idx + len(_DS_TCALL_END)
+
+    # Parse individual tool calls inside the block.
+    for match in _DS_TOOL_CALL_RE.finditer(block):
+      func_name = match.group(1).strip()
+      args_raw = match.group(2).strip()
+      args_json = _extract_json_from_deepseek_args(args_raw)
+      if not func_name or not args_json:
+        continue
+      tool_call = _build_tool_call_from_json_dict(
+          {"name": func_name, "arguments": args_json},
+          index=len(tool_calls),
+      )
+      if tool_call:
+        tool_calls.append(tool_call)
+
+  remainder = "".join(p for p in remainder_parts if p).strip()
+  return tool_calls, remainder or None
+
+
 def _parse_tool_calls_from_text(
     text_block: str,
 ) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
@@ -1318,6 +1471,17 @@ def _parse_tool_calls_from_text(
 
   _ensure_litellm_imported()
 
+  # Try DeepSeek proprietary format first, then fall back to generic JSON.
+  ds_tool_calls, ds_remainder = _parse_deepseek_tool_calls_from_text(text_block)
+  if ds_tool_calls:
+    # If the remainder still contains content, re-parse it for
+    # additional generic inline JSON tool calls (mixed formats).
+    if ds_remainder:
+      extra_calls, extra_remainder = _parse_tool_calls_from_text(ds_remainder)
+      tool_calls = ds_tool_calls + (extra_calls or [])
+      return tool_calls, extra_remainder
+    return ds_tool_calls, None
+
   remainder_segments = []
   cursor = 0
   text_length = len(text_block)
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index c195076349..3fee9eac79 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -43,6 +43,7 @@
 from google.adk.models.lite_llm import _MISSING_TOOL_RESULT_MESSAGE
 from google.adk.models.lite_llm import _model_response_to_chunk
 from google.adk.models.lite_llm import _model_response_to_generate_content_response
+from google.adk.models.lite_llm import _parse_deepseek_tool_calls_from_text
 from google.adk.models.lite_llm import _parse_tool_calls_from_text
 from google.adk.models.lite_llm import _redirect_litellm_loggers_to_stdout
 from google.adk.models.lite_llm import _safe_json_serialize
@@ -2692,6 +2693,158 @@ def test_parse_tool_calls_from_text_invalid_json_returns_remainder():
   assert remainder == 'Leading {"unused": "payload"} trailing text'
 
 
+# ---------------------------------------------------------------------------
+# DeepSeek proprietary inline tool-call format tests
+# ---------------------------------------------------------------------------
+
+_DS_BEGIN_CALLS = "\u003c\uff5ctool\u2581calls\u2581begin\uff5c\u003e"
+_DS_END_CALLS = "\u003c\uff5ctool\u2581calls\u2581end\uff5c\u003e"
+_DS_BEGIN_CALL = "\u003c\uff5ctool\u2581call\u2581begin\uff5c\u003e"
+_DS_END_CALL = "\u003c\uff5ctool\u2581call\u2581end\uff5c\u003e"
+_DS_SEP = "\u003c\uff5ctool\u2581sep\uff5c\u003e"
+
+
+def _ds_tool_call(name: str, args_json: str) -> str:
+  """Build a single DeepSeek-style tool-call block."""
+  return (
+      f"{_DS_BEGIN_CALL}function{_DS_SEP}{name}\n"
+      f"```json\n{args_json}\n```"
+      f"{_DS_END_CALL}"
+  )
+
+
+def _ds_wrapped(inner: str) -> str:
+  """Wrap content in <|tool▁calls▁begin|>...<|tool▁calls▁end|>."""
+  return f"{_DS_BEGIN_CALLS}{inner}{_DS_END_CALLS}"
+
+
+def test_parse_deepseek_single_tool_call():
+  """Single DeepSeek tool call with code-fenced JSON args."""
+  text = _ds_wrapped(
+      _ds_tool_call("get_weather", '{"city": "Beijing", "unit": "celsius"}')
+  )
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "get_weather"
+  assert json.loads(tool_calls[0].function.arguments) == {
+      "city": "Beijing",
+      "unit": "celsius",
+  }
+  assert remainder is None
+
+
+def test_parse_deepseek_multi_tool_call():
+  """Multiple DeepSeek tool calls in a single wrapped block."""
+  inner = _ds_tool_call("func_a", '{"x": 1}') + _ds_tool_call(
+      "func_b", '{"y": 2}'
+  )
+  text = _ds_wrapped(inner)
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+  assert len(tool_calls) == 2
+  assert tool_calls[0].function.name == "func_a"
+  assert json.loads(tool_calls[0].function.arguments) == {"x": 1}
+  assert tool_calls[1].function.name == "func_b"
+  assert json.loads(tool_calls[1].function.arguments) == {"y": 2}
+  assert remainder is None
+
+
+def test_parse_deepseek_plain_json_args():
+  """DeepSeek tool call without Markdown code fences around args."""
+  inner = (
+      f"{_DS_BEGIN_CALL}function{_DS_SEP}search\n"
+      f'{{"query": "天气"}}'
+      f"{_DS_END_CALL}"
+  )
+  text = _ds_wrapped(inner)
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "search"
+  assert json.loads(tool_calls[0].function.arguments) == {"query": "天气"}
+
+
+def test_parse_deepseek_with_surrounding_text():
+  """DeepSeek tool call embedded in surrounding non-tool text."""
+  prefix = "Let me think about this.\n"
+  suffix = "\nI'll proceed now."
+  inner = _ds_tool_call("calculate", '{"expr": "2+2"}')
+  text = prefix + _ds_wrapped(inner) + suffix
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "calculate"
+  assert remainder == "Let me think about this.\n\nI'll proceed now."
+
+
+def test_parse_deepseek_no_tokens_returns_empty():
+  """Text without DeepSeek tokens returns no tool calls and None remainder."""
+  text = "Just a regular response, no special tokens here."
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+  assert tool_calls == []
+  assert remainder is None
+
+
+def test_parse_tool_calls_from_text_handles_deepseek_format():
+  """Integration: the generic parser delegates to the DeepSeek parser."""
+  text = _ds_wrapped(
+      _ds_tool_call("fetch_page", '{"url": "https://example.com"}')
+  )
+  tool_calls, remainder = _parse_tool_calls_from_text(text)
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "fetch_page"
+  assert json.loads(tool_calls[0].function.arguments) == {
+      "url": "https://example.com"
+  }
+  assert remainder is None
+
+
+def test_parse_tool_calls_from_text_mixed_formats():
+  """DeepSeek tokens + standard inline JSON in the same text."""
+  ds_part = _ds_wrapped(_ds_tool_call("ds_func", '{"a": 1}'))
+  standard_part = '{"name": "std_func", "arguments": {"b": 2}}'
+  text = ds_part + " some text " + standard_part
+  tool_calls, remainder = _parse_tool_calls_from_text(text)
+  assert len(tool_calls) == 2
+  assert tool_calls[0].function.name == "ds_func"
+  assert tool_calls[1].function.name == "std_func"
+  assert remainder == "some text"
+
+
+def test_parse_deepseek_empty_text():
+  """Empty or whitespace-only text returns no tool calls."""
+  for text in ("", " ", "\n\n"):
+    tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+    assert tool_calls == []
+    assert remainder is None
+
+
+def test_parse_deepseek_unwrapped_call_before_wrapped_block():
+  """An unwrapped call ahead of a wrapped block is parsed in order."""
+  text = (
+      _ds_tool_call("first", '{"a": 1}')
+      + " between "
+      + _ds_wrapped(_ds_tool_call("second", '{"b": 2}'))
+  )
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
+  assert [tc.function.name for tc in tool_calls] == ["first", "second"]
+  assert remainder == "between"
+
+
+def test_parse_deepseek_args_with_fence_inside_string_value():
+  """Args whose string value contains ``` fences are kept whole."""
+  args = '{"code": "```{x}```"}'
+  inner = (
+      f"{_DS_BEGIN_CALL}function{_DS_SEP}write_file\n"
+      f"{args}"
+      f"{_DS_END_CALL}"
+  )
+  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(
+      _ds_wrapped(inner)
+  )
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "write_file"
+  assert json.loads(tool_calls[0].function.arguments) == json.loads(args)
+  assert remainder is None
+
+
 def test_split_message_content_and_tool_calls_inline_text():
   message = {
       "role": "assistant",