From 4b7f6c43f68ccc9b277b6dcf33da0fd4b738a3de Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Fri, 8 May 2026 09:29:05 -0700 Subject: [PATCH] feat: Add MCP support to async generate_content PiperOrigin-RevId: 912562893 --- google/genai/_adapters.py | 2 + google/genai/_extra_utils.py | 45 ++-- google/genai/_live_converters.py | 30 ++- google/genai/caches.py | 30 ++- google/genai/models.py | 245 ++++++++++++------ .../models/test_generate_content_tools.py | 64 +++++ 6 files changed, 309 insertions(+), 107 deletions(-) diff --git a/google/genai/_adapters.py b/google/genai/_adapters.py index d1dddb2b2..73105f6e3 100644 --- a/google/genai/_adapters.py +++ b/google/genai/_adapters.py @@ -30,9 +30,11 @@ def __init__( self, session: "mcp.ClientSession", # type: ignore # noqa: F821 list_tools_result: "mcp_types.ListToolsResult", # type: ignore + is_agent_platform: bool = False, ) -> None: self._mcp_session = session self._list_tools_result = list_tools_result + self._is_agent_platform = is_agent_platform async def call_tool( self, function_call: FunctionCall diff --git a/google/genai/_extra_utils.py b/google/genai/_extra_utils.py index 129c05f7d..52e03036b 100644 --- a/google/genai/_extra_utils.py +++ b/google/genai/_extra_utils.py @@ -120,18 +120,9 @@ def format_destination( def find_afc_incompatible_tool_indexes( config: Optional[types.GenerateContentConfigOrDict] = None, + is_agent_platform: bool = False, ) -> list[int]: - """Checks if the config contains any AFC incompatible tools. - - A `types.Tool` object that contains `function_declarations` is considered a - non-AFC tool for this execution path. - - Args: - config: The GenerateContentConfig to check for incompatible tools. - - Returns: - A list of indexes of the incompatible tools in the config. - """ + """Checks if the config contains any AFC incompatible tools.""" if not config: return [] config_model = _create_generate_content_config_model(config) @@ -145,7 +136,9 @@ def find_afc_incompatible_tool_indexes( continue if tool.function_declarations: incompatible_tools_indexes.append(index) - if tool.mcp_servers: + + # Only mark it incompatible if it's MLDev, not Agent Platform. + if tool.mcp_servers and not is_agent_platform: incompatible_tools_indexes.append(index) return incompatible_tools_indexes @@ -383,12 +376,15 @@ async def get_function_response_parts_async( if not part.function_call: continue func_name = part.function_call.name - if func_name is not None and part.function_call.args is not None: + if func_name is not None: func = function_map[func_name] - args = convert_number_values_for_dict_function_call_args( + # Treat None as an empty dictionary for execution + raw_args = ( part.function_call.args + if part.function_call.args is not None + else {} ) - func_response: _common.StringDict + args = convert_number_values_for_dict_function_call_args(raw_args) try: if isinstance(func, McpToGenAiToolAdapter): mcp_tool_response = await func.call_tool( @@ -551,6 +547,7 @@ def parse_config_for_mcp_usage( async def parse_config_for_mcp_sessions( config: Optional[types.GenerateContentConfigOrDict] = None, + is_agent_platform: bool = False, ) -> tuple[ Optional[types.GenerateContentConfig], dict[str, McpToGenAiToolAdapter], @@ -571,7 +568,7 @@ async def parse_config_for_mcp_sessions( for tool in parsed_config.tools: if McpClientSession is not None and isinstance(tool, McpClientSession): mcp_to_genai_tool_adapter = McpToGenAiToolAdapter( - tool, await tool.list_tools() + tool, await tool.list_tools(), is_agent_platform=is_agent_platform ) # Extend the config with the MCP session tools converted to GenAI tools. parsed_config_copy.tools.extend(mcp_to_genai_tool_adapter.tools) @@ -677,3 +674,19 @@ def prepare_resumable_upload( http_options.headers = {} http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file) return http_options, size_bytes, mime_type + + +def has_agent_platform_mcp_servers( + config: Optional[types.GenerateContentConfigOrDict] = None, +) -> bool: + """Checks whether the configuration contains any MCP server requests.""" + if not config: + return False + config_model = _create_generate_content_config_model(config) + if not config_model.tools: + return False + + for tool in config_model.tools: + if isinstance(tool, types.Tool) and tool.mcp_servers: + return True + return False diff --git a/google/genai/_live_converters.py b/google/genai/_live_converters.py index dae6866d3..eadc044a1 100644 --- a/google/genai/_live_converters.py +++ b/google/genai/_live_converters.py @@ -1483,6 +1483,26 @@ def _LiveServerMessage_from_vertex( return to_object +def _McpServer_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['name']) is not None: + raise ValueError( + 'name parameter is only supported in Gemini Developer API mode, not in' + ' Gemini Enterprise Agent Platform mode.' + ) + + if getv(from_object, ['streamable_http_transport']) is not None: + raise ValueError( + 'streamable_http_transport parameter is only supported in Gemini' + ' Developer API mode, not in Gemini Enterprise Agent Platform mode.' + ) + + return to_object + + def _MultiSpeakerVoiceConfig_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -1893,9 +1913,13 @@ def _Tool_to_vertex( setv(to_object, ['urlContext'], getv(from_object, ['url_context'])) if getv(from_object, ['mcp_servers']) is not None: - raise ValueError( - 'mcp_servers parameter is only supported in Gemini Developer API mode,' - ' not in Gemini Enterprise Agent Platform mode.' + setv( + to_object, + ['mcpServers'], + [ + _McpServer_to_vertex(item, to_object) + for item in getv(from_object, ['mcp_servers']) + ], ) return to_object diff --git a/google/genai/caches.py b/google/genai/caches.py index c26a959c6..ad95dd24b 100644 --- a/google/genai/caches.py +++ b/google/genai/caches.py @@ -625,6 +625,26 @@ def _ListCachedContentsResponse_from_vertex( return to_object +def _McpServer_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['name']) is not None: + raise ValueError( + 'name parameter is only supported in Gemini Developer API mode, not in' + ' Gemini Enterprise Agent Platform mode.' + ) + + if getv(from_object, ['streamable_http_transport']) is not None: + raise ValueError( + 'streamable_http_transport parameter is only supported in Gemini' + ' Developer API mode, not in Gemini Enterprise Agent Platform mode.' + ) + + return to_object + + def _Part_to_mldev( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -961,9 +981,13 @@ def _Tool_to_vertex( setv(to_object, ['urlContext'], getv(from_object, ['url_context'])) if getv(from_object, ['mcp_servers']) is not None: - raise ValueError( - 'mcp_servers parameter is only supported in Gemini Developer API mode,' - ' not in Gemini Enterprise Agent Platform mode.' + setv( + to_object, + ['mcpServers'], + [ + _McpServer_to_vertex(item, to_object) + for item in getv(from_object, ['mcp_servers']) + ], ) return to_object diff --git a/google/genai/models.py b/google/genai/models.py index c004c1dfa..8c317b454 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -15,6 +15,7 @@ # Code generated by the Google Gen AI SDK generator DO NOT EDIT. +import contextlib import json import logging from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union @@ -3341,6 +3342,27 @@ def _MaskReferenceConfig_to_vertex( return to_object +def _McpServer_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + root_object: Optional[Union[dict[str, Any], object]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['name']) is not None: + raise ValueError( + 'name parameter is only supported in Gemini Developer API mode, not in' + ' Gemini Enterprise Agent Platform mode.' + ) + + if getv(from_object, ['streamable_http_transport']) is not None: + raise ValueError( + 'streamable_http_transport parameter is only supported in Gemini' + ' Developer API mode, not in Gemini Enterprise Agent Platform mode.' + ) + + return to_object + + def _Model_from_mldev( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -4350,9 +4372,13 @@ def _Tool_to_vertex( setv(to_object, ['urlContext'], getv(from_object, ['url_context'])) if getv(from_object, ['mcp_servers']) is not None: - raise ValueError( - 'mcp_servers parameter is only supported in Gemini Developer API mode,' - ' not in Gemini Enterprise Agent Platform mode.' + setv( + to_object, + ['mcpServers'], + [ + _McpServer_to_vertex(item, to_object, root_object) + for item in getv(from_object, ['mcp_servers']) + ], ) return to_object @@ -8452,96 +8478,141 @@ async def generate_content( print(response.text) # J'aime les bagels. """ - # Retrieve and cache any MCP sessions if provided. incompatible_tools_indexes = ( - _extra_utils.find_afc_incompatible_tool_indexes(config) - ) - parsed_config, mcp_to_genai_tool_adapters = ( - await _extra_utils.parse_config_for_mcp_sessions(config) - ) - if _extra_utils.should_disable_afc(parsed_config): - return await self._generate_content( - model=model, contents=contents, config=parsed_config - ) - if incompatible_tools_indexes: - original_tools_length = 0 - if isinstance(config, types.GenerateContentConfig): - if config.tools: - original_tools_length = len(config.tools) - elif isinstance(config, dict): - tools = config.get('tools', []) - if tools: - original_tools_length = len(tools) - if len(incompatible_tools_indexes) != original_tools_length: - indices_str = ', '.join(map(str, incompatible_tools_indexes)) - logger.warning( - 'Tools at indices [%s] are not compatible with automatic function ' - 'calling (AFC). AFC is disabled. If AFC is intended, please ' - 'include python callables in the tool list, and do not include ' - 'function declaration and MCP server in the tool list.', - indices_str, + _extra_utils.find_afc_incompatible_tool_indexes( + config, is_agent_platform=bool(self._api_client.vertexai) ) - return await self._generate_content( - model=model, contents=contents, config=parsed_config - ) - remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc( - parsed_config - ) - logger.info( - f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.' ) - automatic_function_calling_history: list[types.Content] = [] - response = types.GenerateContentResponse() - while remaining_remote_calls_afc > 0: - response = await self._generate_content( - model=model, contents=contents, config=parsed_config - ) - remaining_remote_calls_afc -= 1 - if remaining_remote_calls_afc == 0: - logger.info('Reached max remote calls for automatic function calling.') - function_map = _extra_utils.get_function_map( - parsed_config, mcp_to_genai_tool_adapters, is_caller_method_async=True - ) - if not function_map: - break - if not response: - break + if config is None: + parsed_config = types.GenerateContentConfig() + elif isinstance(config, dict): + parsed_config = types.GenerateContentConfig(**config) + else: + parsed_config = config.model_copy(deep=True) + + # Use AsyncExitStack to keep MCP connections alive across the entire AFC loop + async with contextlib.AsyncExitStack() as stack: + + # Intercept Vertex MCP servers and open connections if ( - not response.candidates - or not response.candidates[0].content - or not response.candidates[0].content.parts + self._api_client.vertexai + and _extra_utils.has_agent_platform_mcp_servers(parsed_config) ): - break - func_response_parts = ( - await _extra_utils.get_function_response_parts_async( - response, function_map + new_tools = [] + if parsed_config.tools: + for tool in parsed_config.tools: + if isinstance(tool, types.Tool) and tool.mcp_servers: + for server in tool.mcp_servers: + # Open the stream and tie its lifespan to the AsyncExitStack + session = await stack.enter_async_context( + _mcp_utils._connect_agent_platform_mcp( + self._api_client, server.name + ) + ) + new_tools.append(session) + else: + new_tools.append(tool) + parsed_config.tools = new_tools + + # Convert active sessions to tools and adapters + final_parsed_config, mcp_to_genai_tool_adapters = ( + await _extra_utils.parse_config_for_mcp_sessions( + parsed_config, is_agent_platform=bool(self._api_client.vertexai) ) ) - if not func_response_parts: - break - func_call_content = response.candidates[0].content - func_response_content = types.Content( - role='user', - parts=func_response_parts, - ) - contents = t.t_contents(contents) # type: ignore[assignment] - if not automatic_function_calling_history: - automatic_function_calling_history.extend(contents) # type: ignore[arg-type] - if isinstance(contents, list): - contents.append(func_call_content) # type: ignore[arg-type] - contents.append(func_response_content) # type: ignore[arg-type] - automatic_function_calling_history.append(func_call_content) - automatic_function_calling_history.append(func_response_content) - if ( - _extra_utils.should_append_afc_history(parsed_config) - and response is not None - ): - response.automatic_function_calling_history = ( - automatic_function_calling_history + if _extra_utils.should_disable_afc(final_parsed_config): + return await self._generate_content( + model=model, contents=contents, config=final_parsed_config + ) + + if incompatible_tools_indexes: + original_tools_length = 0 + if isinstance(config, types.GenerateContentConfig): + if config.tools: + original_tools_length = len(config.tools) + elif isinstance(config, dict): + tools = config.get('tools', []) + if tools: + original_tools_length = len(tools) + if len(incompatible_tools_indexes) != original_tools_length: + indices_str = ', '.join(map(str, incompatible_tools_indexes)) + logger.warning( + 'Tools at indices [%s] are not compatible with automatic function' + ' calling (AFC). AFC is disabled. If AFC is intended, please' + ' include python callables in the tool list, and do not include' + ' function declaration and MCP server in the tool list.', + indices_str, + ) + return await self._generate_content( + model=model, contents=contents, config=final_parsed_config + ) + + remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc( + final_parsed_config ) - return response + logger.info( + f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.' + ) + automatic_function_calling_history: list[types.Content] = [] + response = types.GenerateContentResponse() + + while remaining_remote_calls_afc > 0: + response = await self._generate_content( + model=model, contents=contents, config=final_parsed_config + ) + remaining_remote_calls_afc -= 1 + if remaining_remote_calls_afc == 0: + logger.info( + 'Reached max remote calls for automatic function calling.' + ) + + function_map = _extra_utils.get_function_map( + final_parsed_config, + mcp_to_genai_tool_adapters, + is_caller_method_async=True, + ) + if not function_map: + break + if not response: + break + if ( + not response.candidates + or not response.candidates[0].content + or not response.candidates[0].content.parts + ): + break + func_response_parts = ( + await _extra_utils.get_function_response_parts_async( + response, function_map + ) + ) + if not func_response_parts: + break + func_call_content = response.candidates[0].content + func_response_content = types.Content( + role='user', + parts=func_response_parts, + ) + contents = t.t_contents(contents) # type: ignore[assignment] + if not automatic_function_calling_history: + automatic_function_calling_history.extend(contents) # type: ignore[arg-type] + if isinstance(contents, list): + contents.append(func_call_content) # type: ignore[arg-type] + contents.append(func_response_content) # type: ignore[arg-type] + automatic_function_calling_history.append(func_call_content) + automatic_function_calling_history.append(func_response_content) + + if ( + _extra_utils.should_append_afc_history(final_parsed_config) + and response is not None + ): + response.automatic_function_calling_history = ( + automatic_function_calling_history + ) + + return response async def generate_content_stream( self, @@ -8611,11 +8682,15 @@ async def generate_content_stream( # Retrieve and cache any MCP sessions if provided. incompatible_tools_indexes = ( - _extra_utils.find_afc_incompatible_tool_indexes(config) + _extra_utils.find_afc_incompatible_tool_indexes( + config, is_agent_platform=bool(self._api_client.vertexai) + ) ) # Retrieve and cache any MCP sessions if provided. parsed_config, mcp_to_genai_tool_adapters = ( - await _extra_utils.parse_config_for_mcp_sessions(config) + await _extra_utils.parse_config_for_mcp_sessions( + config, is_agent_platform=bool(self._api_client.vertexai) + ) ) if _extra_utils.should_disable_afc(parsed_config): response = await self._generate_content_stream( diff --git a/google/genai/tests/models/test_generate_content_tools.py b/google/genai/tests/models/test_generate_content_tools.py index 0693ad621..c751508e3 100644 --- a/google/genai/tests/models/test_generate_content_tools.py +++ b/google/genai/tests/models/test_generate_content_tools.py @@ -27,6 +27,19 @@ from ... import types from .. import pytest_helper +import contextlib +from unittest import mock +import pytest +from google.genai import types + +try: + from mcp import types as mcp_types + from mcp import ClientSession +except ImportError: + mcp_types = None + ClientSession = None + + GOOGLE_HOMEPAGE_FILE_PATH = os.path.abspath( os.path.join(os.path.dirname(__file__), '../data/google_homepage.png') ) @@ -1911,3 +1924,54 @@ def test_server_side_mcp_only_stream(client): ) for chunk in response: pass + + +@pytest.mark.asyncio +async def test_client_side_mcp_unary_async(client): + """Test client-side MCP execution for Vertex AI.""" + + if not client._api_client.vertexai: + pytest.skip("Vertex MCP test is not applicable to MLDev.") + + if mcp_types is None: + pytest.skip("MCP library is not installed.") + + mock_session = mock.AsyncMock(spec=ClientSession) + mock_session.list_tools.return_value = mcp_types.ListToolsResult( + tools=[ + mcp_types.Tool( + name="list_endpoints", + description="Lists endpoints", + inputSchema={"type": "object", "properties": {}} + ) + ] + ) + + mock_session.call_tool.return_value = mcp_types.CallToolResult( + content=[mcp_types.TextContent(type="text", text="Endpoint list: [my-endpoint-123]")] + ) + + @contextlib.asynccontextmanager + async def mock_connect(*args, **kwargs): + yield mock_session + + # Patch the connection helper to bypass live HTTP requests to the MCP server + with mock.patch('google.genai._mcp_utils._connect_agent_platform_mcp', side_effect=mock_connect): + + response = await client.aio.models.generate_content( + model='gemini-2.5-flash', + contents='List my endpoints in us-central1.', + config={ + 'tools': [ + types.Tool( + mcp_servers=[types.McpServer(name='endpoints')] + ) + ], + 'automatic_function_calling': {'disable': False} + } + ) + + # 6. Verify the AFC loop successfully ran the mock tool + assert response.text is not None + assert mock_session.list_tools.called + assert mock_session.call_tool.called