From b055bef6d16987e12f1649e4042b4aa65d5e1ae4 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 28 Apr 2026 16:12:32 -0700 Subject: [PATCH] feat: Add mTLS support for telemetry endpoint in adk.py. This change enables the telemetry exporter to use mTLS endpoints when configured, by dynamically determining the correct endpoint and configuring the requests session accordingly. It introduces helper functions to handle client certificate source management. PiperOrigin-RevId: 907230510 --- .../test_agent_engine_templates_adk.py | 279 +++++++++++++++++- vertexai/agent_engines/templates/adk.py | 176 ++++++++--- 2 files changed, 411 insertions(+), 44 deletions(-) diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 93b10c1214..c81258f982 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -16,29 +16,30 @@ import importlib import json import os -import cloudpickle -import sys import re -from unittest import mock +import sys from typing import Optional +from unittest import mock +import uuid +import cloudpickle from google import auth +from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials +from google.auth.transport import mtls from google.cloud import storage -import vertexai from google.cloud import aiplatform -from google.cloud.aiplatform_v1 import types as aip_types -from google.cloud.aiplatform_v1.services import reasoning_engine_service +import vertexai from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer -from vertexai.agent_engines import _utils +from google.cloud.aiplatform_v1 import types as aip_types +from google.cloud.aiplatform_v1.services import reasoning_engine_service from vertexai import agent_engines -from vertexai.agent_engines.templates import adk as adk_template from vertexai.agent_engines import _agent_engines -from google.api_core import operation as ga_operation +from vertexai.agent_engines import _utils +from vertexai.agent_engines.templates import adk as adk_template from google.genai import types import pytest -import uuid try: @@ -1066,6 +1067,7 @@ def update_agent_engine_mock(): @pytest.mark.usefixtures("google_auth_mock") class TestAgentEngines: + def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) @@ -1167,3 +1169,260 @@ def test_update_default_telemetry_enablement( assert _utils.to_dict(deployment_spec)["env"] == [ {"name": key, "value": value} for key, value in expected_env_vars.items() ] + + +class TestAdkAppMtls: + """Test cases for mTLS functionality in AdkApp.""" + + def test_use_client_cert_effective_with_should_use_client_cert(self): + """Verifies that it respects the google-auth mTLS enablement check.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + return_value=True, + create=True, + ): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_use_client_cert_effective_with_env_var_true(self): + """Verifies that it falls back to the environment variable if google-auth check fails.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}) + def test_use_client_cert_effective_with_env_var_false(self): + """Verifies that it respects the environment variable being set to false.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_default(self): + """Verifies the default telemetry endpoint is returned when no mTLS is configured.""" + assert ( + adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_with_cert(self): + """Verifies the mTLS endpoint is used when forced and a certificate is available.""" + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) + def test_get_api_endpoint_auto_no_cert(self): + """Verifies it falls back to regular endpoint even if forced if no certificate is provided.""" + assert ( + adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}) + def test_get_api_endpoint_never(self): + """Verifies the regular endpoint is used when mTLS is explicitly disabled.""" + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_with_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS enabled.""" + # Mocking to enable mTLS + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, "has_default_client_cert_source", return_value=True + ): + with mock.patch.object( + mtls, + "default_client_cert_source", + return_value=lambda: b"cert", + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Verify the session was configured for mTLS + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + # Verify the exporter was initialized with the mTLS endpoint + mock_exporter.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_with_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS enabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + # Mocking to enable mTLS + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, "has_default_client_cert_source", return_value=True + ): + with mock.patch.object( + mtls, + "default_client_cert_source", + return_value=lambda: b"cert", + ): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was configured for the check request + mock_session.configure_mtls_channel.assert_called_once() + # Verify the check was performed against the mTLS endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT, data=None + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "invalid_value"}) + def test_get_api_endpoint_invalid_env(self): + """Verifies it defaults to AUTO and warns on invalid env var.""" + with mock.patch.object(adk_template, "_warn") as mock_warn: + assert ( + adk_template._get_api_endpoint() + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + mock_warn.assert_called_once() + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "not_a_bool"}) + def test_use_client_cert_effective_invalid_env(self): + """Verifies it warns on invalid boolean env var.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + with mock.patch.object(adk_template, "_warn") as mock_warn: + assert adk_template._use_client_cert_effective() is False + mock_warn.assert_called_once() + + def test_use_client_cert_effective_with_should_use_client_cert_false(self): + """Verifies that it respects google-auth returning False for mTLS.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + return_value=False, + create=True, + ): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_auto_with_cert(self): + """Verifies the mTLS endpoint is used in AUTO mode when a cert is available.""" + # AUTO is the default, so we just pass a cert + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_no_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS disabled.""" + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=False + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Verify mTLS channel was NOT configured + mock_session_cls.return_value.configure_mtls_channel.assert_not_called() + # Verify the exporter was initialized with the regular endpoint + mock_exporter.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_no_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS disabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=False + ): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was NOT configured + mock_session.configure_mtls_channel.assert_not_called() + # Verify the check was performed against the regular endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_TELEMETRY_ENDPOINT, data=None + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_mtls_no_cert_source( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Tests that it falls back to regular endpoint if mTLS is on but no cert is found.""" + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, + "has_default_client_cert_source", + return_value=False, + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Channel is configured, but endpoint remains default due to missing cert source + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 1144e7d096..b5f4c86b83 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -13,24 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +from collections.abc import Awaitable +import enum +import os +import queue +import sys +import threading from typing import ( - TYPE_CHECKING, Any, AsyncIterable, Callable, Dict, List, Optional, + TYPE_CHECKING, Union, ) - -import asyncio -from collections.abc import Awaitable -import queue -import sys -import threading import warnings +import google.auth +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth + if TYPE_CHECKING: try: from google.adk.events.event import Event @@ -95,16 +100,26 @@ _DEFAULT_APP_NAME = "default-app-name" _DEFAULT_USER_ID = "default-user-id" -_TELEMETRY_API_DISABLED_WARNING = ( - "Tracing integration for Agent Engine has migrated to a new API.\n" - "The 'telemetry.googleapis.com' has not been enabled in project %s. \n" - "**Impact:** Until this API is enabled, telemetry data will not be stored." - "\n" - "**Action:** Please enable the API by visiting " - "https://console.developers.google.com/apis/api/telemetry.googleapis.com/overview?project=%s." - "\n" - "(If you enabled this API recently, you can safely ignore this warning.)" -) +_TELEMETRY_API_DISABLED_WARNING = """\ +Tracing integration for Agent Engine has migrated to a new API. +The 'telemetry.googleapis.com' has not been enabled in project %s. +**Impact:** Until this API is enabled, telemetry data will not be stored. + +**Action:** Please enable the API by visiting https://console.developers.google.com/apis/api/telemetry.googleapis.com/overview?project=%s. + +(If you enabled this API recently, you can safely ignore this warning.) +""" + +_DEFAULT_TELEMETRY_ENDPOINT = "https://telemetry.googleapis.com/v1/traces" +_DEFAULT_MTLS_TELEMETRY_ENDPOINT = "https://telemetry.mtls.googleapis.com/v1/traces" + + +class MtlsEndpoint(enum.Enum): + """Enum for the mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" def get_adk_version() -> Optional[str]: @@ -293,7 +308,8 @@ def _default_instrumentor_builder( if project_id is None: _warn( - "telemetry is only supported when project is specified, proceeding with no telemetry" + "telemetry is only supported when project is specified, proceeding with" + " no telemetry" ) return None @@ -306,10 +322,19 @@ def _warn_missing_dependency( needed_for_tracing: bool = False, ) -> None: _warn( - f"{package} is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + f"{package} is not installed. Please call 'pip install" + " google-cloud-aiplatform[agent_engines]'." + ) + MISSING_TRACE_IMPORT_ERROR_MESSAGE = ( + "proceeding with tracing disabled because not all packages (i.e." + " `google-cloud-trace`, `opentelemetry-sdk`," + " `opentelemetry-exporter-gcp-trace`) for tracing have been installed" + ) + MISSING_LOGGING_IMPORT_ERROR_MESSAGE = ( + "proceeding with logging disabled because not all packages (i.e." + " `google-cloud-logging`, `opentelemetry-sdk`," + " `opentelemetry-exporter-gcp-logging`) for tracing have been installed" ) - MISSING_TRACE_IMPORT_ERROR_MESSAGE = "proceeding with tracing disabled because not all packages (i.e. `google-cloud-trace`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-trace`) for tracing have been installed" - MISSING_LOGGING_IMPORT_ERROR_MESSAGE = "proceeding with logging disabled because not all packages (i.e. `google-cloud-logging`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-logging`) for tracing have been installed" if needed_for_tracing and enable_tracing: _warn(MISSING_TRACE_IMPORT_ERROR_MESSAGE) @@ -389,14 +414,29 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: credentials, _ = google.auth.default() vertex_sdk_version = aip_version.__version__ otlp_http_version = opentelemetry.exporter.otlp.proto.http.version.__version__ - user_agent = f"Vertex-Agent-Engine/{vertex_sdk_version} OTel-OTLP-Exporter-Python/{otlp_http_version}" + user_agent = ( + f"Vertex-Agent-Engine/{vertex_sdk_version}" + f" OTel-OTLP-Exporter-Python/{otlp_http_version}" + ) + + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT span_exporter = ( opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter( - session=google.auth.transport.requests.AuthorizedSession( - credentials=credentials - ), - endpoint="https://telemetry.googleapis.com/v1/traces", + session=session, + endpoint=endpoint, headers={"User-Agent": user_agent}, ) ) @@ -446,6 +486,7 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: class _SimpleLogRecordProcessor( opentelemetry.sdk._logs.export.SimpleLogRecordProcessor ): + def force_flush( self, timeout_millis: int = 30000 ) -> bool: # pylint: disable=no-self-use @@ -499,7 +540,9 @@ def force_flush( google_genai.GoogleGenAiSdkInstrumentor().instrument() except (ImportError, AttributeError): _warn( - "telemetry enabled but proceeding without GenAI instrumentation, because not all packages (i.e. opentelemetry-instrumentation-google-genai) have been installed" + "telemetry enabled but proceeding without GenAI instrumentation," + " because not all packages (i.e." + " opentelemetry-instrumentation-google-genai) have been installed" ) return None @@ -546,18 +589,83 @@ def _validate_run_config(run_config: Optional[Dict[str, Any]]): def _warn_if_telemetry_api_disabled(): """Warn if telemetry API is disabled.""" - try: - import google.auth.transport.requests - import google.auth - except (ImportError, AttributeError): - return credentials, project = google.auth.default() - session = google.auth.transport.requests.AuthorizedSession(credentials=credentials) - r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + r = session.post(endpoint, data=None) if "Telemetry API has not been used in project" in r.text: _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) +def _get_api_endpoint(client_cert_source: bytes | None = None) -> str: + """Returns API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source (bytes | None): The client certificate source. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + _warn( + f"Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of " + f"{[e.value for e in MtlsEndpoint]}. Defaulting to" + f" {MtlsEndpoint.AUTO.value}." + ) + use_mtls_endpoint = MtlsEndpoint.AUTO + + if (use_mtls_endpoint == MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint == MtlsEndpoint.AUTO and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_ENDPOINT + + return _DEFAULT_TELEMETRY_ENDPOINT + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + # check if google-auth version supports should_use_client_cert for automatic + # mTLS enablement + try: + return mtls.should_use_client_cert() + except (ImportError, AttributeError): + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + _warn( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + + class AdkApp: """An ADK Application."""