diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 93b10c1214..c81258f982 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -16,29 +16,30 @@ import importlib import json import os -import cloudpickle -import sys import re -from unittest import mock +import sys from typing import Optional +from unittest import mock +import uuid +import cloudpickle from google import auth +from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials +from google.auth.transport import mtls from google.cloud import storage -import vertexai from google.cloud import aiplatform -from google.cloud.aiplatform_v1 import types as aip_types -from google.cloud.aiplatform_v1.services import reasoning_engine_service +import vertexai from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer -from vertexai.agent_engines import _utils +from google.cloud.aiplatform_v1 import types as aip_types +from google.cloud.aiplatform_v1.services import reasoning_engine_service from vertexai import agent_engines -from vertexai.agent_engines.templates import adk as adk_template from vertexai.agent_engines import _agent_engines -from google.api_core import operation as ga_operation +from vertexai.agent_engines import _utils +from vertexai.agent_engines.templates import adk as adk_template from google.genai import types import pytest -import uuid try: @@ -1066,6 +1067,7 @@ def update_agent_engine_mock(): @pytest.mark.usefixtures("google_auth_mock") class TestAgentEngines: + def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) @@ -1167,3 +1169,260 @@ def test_update_default_telemetry_enablement( assert _utils.to_dict(deployment_spec)["env"] == [ {"name": key, "value": value} for key, value in expected_env_vars.items() ] + + +class TestAdkAppMtls: + """Test cases for mTLS functionality in AdkApp.""" + + def test_use_client_cert_effective_with_should_use_client_cert(self): + """Verifies that it respects the google-auth mTLS enablement check.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + return_value=True, + create=True, + ): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_use_client_cert_effective_with_env_var_true(self): + """Verifies that it falls back to the environment variable if google-auth check fails.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}) + def test_use_client_cert_effective_with_env_var_false(self): + """Verifies that it respects the environment variable being set to false.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_default(self): + """Verifies the default telemetry endpoint is returned when no mTLS is configured.""" + assert ( + adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_with_cert(self): + """Verifies the mTLS endpoint is used when forced and a certificate is available.""" + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) + def test_get_api_endpoint_auto_no_cert(self): + """Verifies it falls back to regular endpoint even if forced if no certificate is provided.""" + assert ( + adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}) + def test_get_api_endpoint_never(self): + """Verifies the regular endpoint is used when mTLS is explicitly disabled.""" + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_with_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS enabled.""" + # Mocking to enable mTLS + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, "has_default_client_cert_source", return_value=True + ): + with mock.patch.object( + mtls, + "default_client_cert_source", + return_value=lambda: b"cert", + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Verify the session was configured for mTLS + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + # Verify the exporter was initialized with the mTLS endpoint + mock_exporter.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_with_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS enabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + # Mocking to enable mTLS + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, "has_default_client_cert_source", return_value=True + ): + with mock.patch.object( + mtls, + "default_client_cert_source", + return_value=lambda: b"cert", + ): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was configured for the check request + mock_session.configure_mtls_channel.assert_called_once() + # Verify the check was performed against the mTLS endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT, data=None + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "invalid_value"}) + def test_get_api_endpoint_invalid_env(self): + """Verifies it defaults to AUTO and warns on invalid env var.""" + with mock.patch.object(adk_template, "_warn") as mock_warn: + assert ( + adk_template._get_api_endpoint() + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + mock_warn.assert_called_once() + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "not_a_bool"}) + def test_use_client_cert_effective_invalid_env(self): + """Verifies it warns on invalid boolean env var.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + with mock.patch.object(adk_template, "_warn") as mock_warn: + assert adk_template._use_client_cert_effective() is False + mock_warn.assert_called_once() + + def test_use_client_cert_effective_with_should_use_client_cert_false(self): + """Verifies that it respects google-auth returning False for mTLS.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + return_value=False, + create=True, + ): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_auto_with_cert(self): + """Verifies the mTLS endpoint is used in AUTO mode when a cert is available.""" + # AUTO is the default, so we just pass a cert + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_no_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS disabled.""" + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=False + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Verify mTLS channel was NOT configured + mock_session_cls.return_value.configure_mtls_channel.assert_not_called() + # Verify the exporter was initialized with the regular endpoint + mock_exporter.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_no_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS disabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=False + ): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was NOT configured + mock_session.configure_mtls_channel.assert_not_called() + # Verify the check was performed against the regular endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_TELEMETRY_ENDPOINT, data=None + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_mtls_no_cert_source( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Tests that it falls back to regular endpoint if mTLS is on but no cert is found.""" + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, + "has_default_client_cert_source", + return_value=False, + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Channel is configured, but endpoint remains default due to missing cert source + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 1144e7d096..b5f4c86b83 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -13,24 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +from collections.abc import Awaitable +import enum +import os +import queue +import sys +import threading from typing import ( - TYPE_CHECKING, Any, AsyncIterable, Callable, Dict, List, Optional, + TYPE_CHECKING, Union, ) - -import asyncio -from collections.abc import Awaitable -import queue -import sys -import threading import warnings +import google.auth +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth + if TYPE_CHECKING: try: from google.adk.events.event import Event @@ -95,16 +100,26 @@ _DEFAULT_APP_NAME = "default-app-name" _DEFAULT_USER_ID = "default-user-id" -_TELEMETRY_API_DISABLED_WARNING = ( - "Tracing integration for Agent Engine has migrated to a new API.\n" - "The 'telemetry.googleapis.com' has not been enabled in project %s. \n" - "**Impact:** Until this API is enabled, telemetry data will not be stored." - "\n" - "**Action:** Please enable the API by visiting " - "https://console.developers.google.com/apis/api/telemetry.googleapis.com/overview?project=%s." - "\n" - "(If you enabled this API recently, you can safely ignore this warning.)" -) +_TELEMETRY_API_DISABLED_WARNING = """\ +Tracing integration for Agent Engine has migrated to a new API. +The 'telemetry.googleapis.com' has not been enabled in project %s. +**Impact:** Until this API is enabled, telemetry data will not be stored. + +**Action:** Please enable the API by visiting https://console.developers.google.com/apis/api/telemetry.googleapis.com/overview?project=%s. + +(If you enabled this API recently, you can safely ignore this warning.) +""" + +_DEFAULT_TELEMETRY_ENDPOINT = "https://telemetry.googleapis.com/v1/traces" +_DEFAULT_MTLS_TELEMETRY_ENDPOINT = "https://telemetry.mtls.googleapis.com/v1/traces" + + +class MtlsEndpoint(enum.Enum): + """Enum for the mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" def get_adk_version() -> Optional[str]: @@ -293,7 +308,8 @@ def _default_instrumentor_builder( if project_id is None: _warn( - "telemetry is only supported when project is specified, proceeding with no telemetry" + "telemetry is only supported when project is specified, proceeding with" + " no telemetry" ) return None @@ -306,10 +322,19 @@ def _warn_missing_dependency( needed_for_tracing: bool = False, ) -> None: _warn( - f"{package} is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + f"{package} is not installed. Please call 'pip install" + " google-cloud-aiplatform[agent_engines]'." + ) + MISSING_TRACE_IMPORT_ERROR_MESSAGE = ( + "proceeding with tracing disabled because not all packages (i.e." + " `google-cloud-trace`, `opentelemetry-sdk`," + " `opentelemetry-exporter-gcp-trace`) for tracing have been installed" + ) + MISSING_LOGGING_IMPORT_ERROR_MESSAGE = ( + "proceeding with logging disabled because not all packages (i.e." + " `google-cloud-logging`, `opentelemetry-sdk`," + " `opentelemetry-exporter-gcp-logging`) for tracing have been installed" ) - MISSING_TRACE_IMPORT_ERROR_MESSAGE = "proceeding with tracing disabled because not all packages (i.e. `google-cloud-trace`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-trace`) for tracing have been installed" - MISSING_LOGGING_IMPORT_ERROR_MESSAGE = "proceeding with logging disabled because not all packages (i.e. `google-cloud-logging`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-logging`) for tracing have been installed" if needed_for_tracing and enable_tracing: _warn(MISSING_TRACE_IMPORT_ERROR_MESSAGE) @@ -389,14 +414,29 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: credentials, _ = google.auth.default() vertex_sdk_version = aip_version.__version__ otlp_http_version = opentelemetry.exporter.otlp.proto.http.version.__version__ - user_agent = f"Vertex-Agent-Engine/{vertex_sdk_version} OTel-OTLP-Exporter-Python/{otlp_http_version}" + user_agent = ( + f"Vertex-Agent-Engine/{vertex_sdk_version}" + f" OTel-OTLP-Exporter-Python/{otlp_http_version}" + ) + + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT span_exporter = ( opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter( - session=google.auth.transport.requests.AuthorizedSession( - credentials=credentials - ), - endpoint="https://telemetry.googleapis.com/v1/traces", + session=session, + endpoint=endpoint, headers={"User-Agent": user_agent}, ) ) @@ -446,6 +486,7 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: class _SimpleLogRecordProcessor( opentelemetry.sdk._logs.export.SimpleLogRecordProcessor ): + def force_flush( self, timeout_millis: int = 30000 ) -> bool: # pylint: disable=no-self-use @@ -499,7 +540,9 @@ def force_flush( google_genai.GoogleGenAiSdkInstrumentor().instrument() except (ImportError, AttributeError): _warn( - "telemetry enabled but proceeding without GenAI instrumentation, because not all packages (i.e. opentelemetry-instrumentation-google-genai) have been installed" + "telemetry enabled but proceeding without GenAI instrumentation," + " because not all packages (i.e." + " opentelemetry-instrumentation-google-genai) have been installed" ) return None @@ -546,18 +589,83 @@ def _validate_run_config(run_config: Optional[Dict[str, Any]]): def _warn_if_telemetry_api_disabled(): """Warn if telemetry API is disabled.""" - try: - import google.auth.transport.requests - import google.auth - except (ImportError, AttributeError): - return credentials, project = google.auth.default() - session = google.auth.transport.requests.AuthorizedSession(credentials=credentials) - r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + r = session.post(endpoint, data=None) if "Telemetry API has not been used in project" in r.text: _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) +def _get_api_endpoint(client_cert_source: bytes | None = None) -> str: + """Returns API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source (bytes | None): The client certificate source. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + _warn( + f"Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of " + f"{[e.value for e in MtlsEndpoint]}. Defaulting to" + f" {MtlsEndpoint.AUTO.value}." + ) + use_mtls_endpoint = MtlsEndpoint.AUTO + + if (use_mtls_endpoint == MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint == MtlsEndpoint.AUTO and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_ENDPOINT + + return _DEFAULT_TELEMETRY_ENDPOINT + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + # check if google-auth version supports should_use_client_cert for automatic + # mTLS enablement + try: + return mtls.should_use_client_cert() + except (ImportError, AttributeError): + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + _warn( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + + class AdkApp: """An ADK Application."""