diff --git a/packages/google-auth/google/auth/environment_vars.py b/packages/google-auth/google/auth/environment_vars.py index c7d706467ed4..c622f1773531 100644 --- a/packages/google-auth/google/auth/environment_vars.py +++ b/packages/google-auth/google/auth/environment_vars.py @@ -129,3 +129,6 @@ "GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES" ) """Environment variable to prevent agent token sharing for GCP services.""" + +GOOGLE_API_USE_MTLS_ENDPOINT = "GOOGLE_API_USE_MTLS_ENDPOINT" +"""Environment variable controlling whether to use mTLS endpoint or not.""" diff --git a/packages/google-auth/google/auth/transport/mtls.py b/packages/google-auth/google/auth/transport/mtls.py index 666a6ca1fd91..c3c1cf186c2b 100644 --- a/packages/google-auth/google/auth/transport/mtls.py +++ b/packages/google-auth/google/auth/transport/mtls.py @@ -14,12 +14,19 @@ """Utilites for mutual TLS.""" +import logging from os import getenv +import ssl +from typing import Optional +from google.auth import environment_vars from google.auth import exceptions from google.auth.transport import _mtls_helper +_LOGGER = logging.getLogger(__name__) + + def has_default_client_cert_source(include_context_aware=True): """Check if default client SSL credentials exists on the device. @@ -60,7 +67,7 @@ def default_client_cert_source(): client certificate bytes and private key bytes, both in PEM format. Raises: - google.auth.exceptions.DefaultClientCertSourceError: If the default + google.auth.exceptions.MutualTLSChannelError: If the default client SSL credentials don't exist or are malformed. """ if not has_default_client_cert_source(include_context_aware=True): @@ -71,7 +78,12 @@ def default_client_cert_source(): def callback(): try: _, cert_bytes, key_bytes = _mtls_helper.get_client_cert_and_key() - except (OSError, RuntimeError, ValueError) as caught_exc: + except ( + exceptions.ClientCertError, + OSError, + RuntimeError, + ValueError, + ) as caught_exc: new_exc = exceptions.MutualTLSChannelError(caught_exc) raise new_exc from caught_exc @@ -96,7 +108,7 @@ def default_client_encrypted_cert_source(cert_path, key_path): returns the cert_path, key_path and passphrase bytes. Raises: - google.auth.exceptions.DefaultClientCertSourceError: If any problem + google.auth.exceptions.MutualTLSChannelError: If any problem occurs when loading or saving the client certificate and key. """ if not has_default_client_cert_source(include_context_aware=True): @@ -140,3 +152,163 @@ def should_use_client_cert(): bool: indicating whether the client certificate should be used for mTLS. """ return _mtls_helper.check_use_client_cert() + + +def load_client_cert_into_context( + ctx: ssl.SSLContext, + cert_bytes: bytes, + key_bytes: bytes, + passphrase: Optional[bytes] = None, +) -> None: + """Load a client certificate and key into an SSL context. + + Args: + ctx (ssl.SSLContext): The SSL context to load the certificate and key into. + cert_bytes (bytes): The client certificate bytes in PEM format. + key_bytes (bytes): The client private key bytes in PEM format. + passphrase (Optional[bytes]): The passphrase for the client private key. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If the SSL context is invalid, + or if loading the certificate and key fails. + """ + if ctx is None or not hasattr(ctx, "load_cert_chain"): + raise exceptions.MutualTLSChannelError( + "Failed to load client certificate and key for mTLS. The provided context " + "object is invalid or does not support loading certificate chains." + ) + + try: + with _mtls_helper.secure_cert_key_paths( + cert_bytes, key_bytes, passphrase=passphrase + ) as ( + cert_path, + key_path, + passphrase_val, + ): + ctx.load_cert_chain( + certfile=cert_path, keyfile=key_path, password=passphrase_val + ) + except ( + ssl.SSLError, + OSError, + ValueError, + RuntimeError, + ) as caught_exc: + new_exc = exceptions.MutualTLSChannelError( + "Failed to load client certificate and key for mTLS." + ) + raise new_exc from caught_exc + + +def make_client_cert_ssl_context( + cert_bytes: bytes, + key_bytes: bytes, + passphrase: Optional[bytes] = None, +) -> ssl.SSLContext: + """Create a default SSL context loaded with the client certificate and key. + + Args: + cert_bytes (bytes): The client certificate bytes in PEM format. + key_bytes (bytes): The client private key bytes in PEM format. + passphrase (Optional[bytes]): The passphrase for the client private key. + + Returns: + ssl.SSLContext: The SSL context loaded with the client certificate and key. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If loading the certificate and key fails. + """ + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + load_client_cert_into_context(ctx, cert_bytes, key_bytes, passphrase=passphrase) + return ctx + + +def load_default_client_cert(ctx: ssl.SSLContext) -> bool: + """Load the default client certificate and key into an SSL context if configured. + + If client certificates are enabled and a default client certificate source is + found, the certificate and key are loaded into the SSL context. + + Args: + ctx (ssl.SSLContext): The SSL context to load the default client certificate + and key into. + + Returns: + bool: True if client certificates are enabled and the default client + certificate was successfully loaded. False if client certificates + are disabled or if no default certificate source is configured. + + Raises: + google.auth.exceptions.ClientCertError: If the default client certificate + source exists but cannot be loaded or parsed. + google.auth.exceptions.MutualTLSChannelError: If the default client certificate + or key is malformed. + """ + if not should_use_client_cert() or not has_default_client_cert_source(): + return False + ( + has_cert, + cert_bytes, + key_bytes, + passphrase, + ) = _mtls_helper.get_client_ssl_credentials() + if not has_cert: + return False + load_client_cert_into_context(ctx, cert_bytes, key_bytes, passphrase) + return True + + +def get_default_ssl_context() -> Optional[ssl.SSLContext]: + """Get a default SSL context loaded with the default client certificate. + + Returns: + ssl.SSLContext: An SSL context loaded with the default client + certificate, or None if client certificates are not configured + or available. + + Raises: + google.auth.exceptions.ClientCertError: If the default client certificate + source exists but cannot be loaded or parsed. + google.auth.exceptions.MutualTLSChannelError: If the default client certificate + or key is malformed. + """ + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + return ctx if load_default_client_cert(ctx) else None + + +def should_use_mtls_endpoint( + client_cert_available: Optional[bool] = None, +) -> bool: + """Determine whether to use an mTLS endpoint. + + This relies on the GOOGLE_API_USE_MTLS_ENDPOINT environment variable. If set to + "always", returns True. If set to "never", returns False. If set to "auto" + or unset, returns whether a client certificate is available. + + Args: + client_cert_available (bool): indicating if a client certificate + is available. If None, this is determined by checking if client + certificates are enabled and a default source is present. + + Returns: + bool: indicating if an mTLS endpoint should be used. + """ + if client_cert_available is None: + client_cert_available = should_use_client_cert() + + use_mtls_endpoint = getenv(environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT, "auto") + use_mtls_endpoint = use_mtls_endpoint.lower() + if use_mtls_endpoint == "always": + return True + if use_mtls_endpoint == "never": + return False + if use_mtls_endpoint == "auto": + return client_cert_available + + _LOGGER.warning( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value %r. Accepted " + "values: never, auto, always. Defaulting to auto.", + use_mtls_endpoint, + ) + return client_cert_available diff --git a/packages/google-auth/tests/transport/test_mtls.py b/packages/google-auth/tests/transport/test_mtls.py index 405cb496cad2..69cbe792575d 100644 --- a/packages/google-auth/tests/transport/test_mtls.py +++ b/packages/google-auth/tests/transport/test_mtls.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import ssl from unittest import mock import pytest # type: ignore @@ -135,6 +137,12 @@ def test_default_client_cert_source( with pytest.raises(exceptions.MutualTLSChannelError): callback() + # Test bad callback which throws ClientCertError. + get_client_cert_and_key.side_effect = exceptions.ClientCertError() + callback = mtls.default_client_cert_source() + with pytest.raises(exceptions.MutualTLSChannelError): + callback() + @mock.patch( "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True @@ -173,3 +181,258 @@ def test_should_use_client_cert(check_use_client_cert): check_use_client_cert.return_value = False assert not mtls.should_use_client_cert() + + +@contextlib.contextmanager +def _fake_secure_paths(cert_bytes, key_bytes, passphrase=None): + yield "cert_path", "key_path", passphrase + + +@mock.patch( + "google.auth.transport._mtls_helper.secure_cert_key_paths", + side_effect=_fake_secure_paths, +) +def test_load_client_cert_into_context_success(mock_secure_paths): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + result = mtls.load_client_cert_into_context( + mock_ctx, b"cert", b"key", passphrase=b"passphrase" + ) + assert result is None + mock_ctx.load_cert_chain.assert_called_once_with( + certfile="cert_path", keyfile="key_path", password=b"passphrase" + ) + + +@mock.patch( + "google.auth.transport._mtls_helper.secure_cert_key_paths", + side_effect=_fake_secure_paths, +) +def test_load_client_cert_into_context_error(mock_secure_paths): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_ctx.load_cert_chain.side_effect = ssl.SSLError("boom") + with pytest.raises(exceptions.MutualTLSChannelError) as exc_info: + mtls.load_client_cert_into_context(mock_ctx, b"cert", b"key") + assert "Failed to load client certificate and key" in str(exc_info.value) + assert isinstance(exc_info.value.__cause__, ssl.SSLError) + + +def test_load_client_cert_into_context_invalid_ctx(): + with pytest.raises(exceptions.MutualTLSChannelError) as exc_info: + mtls.load_client_cert_into_context(None, b"cert", b"key") + assert ( + "The provided context object is invalid or does not support loading certificate chains" + in str(exc_info.value) + ) + assert exc_info.value.__cause__ is None + + +@mock.patch( + "google.auth.transport._mtls_helper.secure_cert_key_paths", + side_effect=_fake_secure_paths, +) +def test_load_client_cert_into_context_load_chain_type_error(mock_secure_paths): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_ctx.load_cert_chain.side_effect = TypeError("invalid password type") + with pytest.raises(TypeError) as exc_info: + mtls.load_client_cert_into_context(mock_ctx, b"cert", b"key") + assert "invalid password type" in str(exc_info.value) + + +@mock.patch("google.auth.transport.mtls.load_client_cert_into_context", autospec=True) +@mock.patch("ssl.create_default_context", autospec=True) +def test_make_client_cert_ssl_context(mock_create_context, mock_load_cert): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_ctx + + result = mtls.make_client_cert_ssl_context(b"cert", b"key", b"passphrase") + + assert result == mock_ctx + mock_create_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + mock_load_cert.assert_called_once_with( + mock_ctx, b"cert", b"key", passphrase=b"passphrase" + ) + + +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_disabled(mock_should_use): + mock_should_use.return_value = False + mock_ctx = mock.Mock(spec=ssl.SSLContext) + assert mtls.load_default_client_cert(mock_ctx) is False + mock_ctx.load_cert_chain.assert_not_called() + + +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_no_source(mock_should_use, mock_has_source): + mock_should_use.return_value = True + mock_has_source.return_value = False + mock_ctx = mock.Mock(spec=ssl.SSLContext) + assert mtls.load_default_client_cert(mock_ctx) is False + mock_ctx.load_cert_chain.assert_not_called() + + +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_no_cert( + mock_should_use, mock_has_source, mock_get_credentials +): + mock_should_use.return_value = True + mock_has_source.return_value = True + mock_get_credentials.return_value = (False, None, None, None) + mock_ctx = mock.Mock(spec=ssl.SSLContext) + assert mtls.load_default_client_cert(mock_ctx) is False + mock_ctx.load_cert_chain.assert_not_called() + + +@mock.patch("google.auth.transport.mtls.load_client_cert_into_context", autospec=True) +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_success( + mock_should_use, mock_has_source, mock_get_credentials, mock_load_cert +): + mock_should_use.return_value = True + mock_has_source.return_value = True + mock_get_credentials.return_value = (True, b"cert", b"key", b"passphrase") + mock_ctx = mock.Mock(spec=ssl.SSLContext) + + assert mtls.load_default_client_cert(mock_ctx) is True + mock_load_cert.assert_called_once_with(mock_ctx, b"cert", b"key", b"passphrase") + + +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_propagates_client_cert_error( + mock_should_use, mock_has_source, mock_get_credentials +): + mock_should_use.return_value = True + mock_has_source.return_value = True + mock_get_credentials.side_effect = exceptions.ClientCertError("credentials failure") + mock_ctx = mock.Mock(spec=ssl.SSLContext) + + with pytest.raises(exceptions.ClientCertError) as exc_info: + mtls.load_default_client_cert(mock_ctx) + assert "credentials failure" in str(exc_info.value) + + +@mock.patch("google.auth.transport.mtls.load_default_client_cert", autospec=True) +@mock.patch("ssl.create_default_context", autospec=True) +def test_get_default_ssl_context_configured(mock_create_context, mock_load_default): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_ctx + mock_load_default.return_value = True + + result = mtls.get_default_ssl_context() + + assert result == mock_ctx + mock_create_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + mock_load_default.assert_called_once_with(mock_ctx) + + +@mock.patch("google.auth.transport.mtls.load_default_client_cert", autospec=True) +@mock.patch("ssl.create_default_context", autospec=True) +def test_get_default_ssl_context_unconfigured(mock_create_context, mock_load_default): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_ctx + mock_load_default.return_value = False + + result = mtls.get_default_ssl_context() + + assert result is None + mock_create_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + mock_load_default.assert_called_once_with(mock_ctx) + + +@pytest.mark.parametrize( + "env_val,client_cert_available,expected", + [ + ("always", True, True), + ("always", False, True), + ("never", True, False), + ("never", False, False), + ("auto", True, True), + ("auto", False, False), + (None, True, True), # Defaults to auto + (None, False, False), # Defaults to auto + ("ALWAYS", True, True), + ("ALWAYS", False, True), + ("NEVER", True, False), + ("NEVER", False, False), + ("AUTO", True, True), + ("AUTO", False, False), + ("invalid_val", True, True), + ("invalid_val", False, False), + ], +) +@mock.patch( + "google.auth.environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT", + "GOOGLE_API_USE_MTLS_ENDPOINT", +) +@mock.patch("google.auth.transport.mtls.getenv", autospec=True) +def test_should_use_mtls_endpoint( + mock_getenv, env_val, client_cert_available, expected +): + mock_getenv.side_effect = ( + lambda var, default=None: env_val + if (var == "GOOGLE_API_USE_MTLS_ENDPOINT" and env_val is not None) + else default + ) + result = mtls.should_use_mtls_endpoint(client_cert_available) + assert result == expected + + +@mock.patch( + "google.auth.environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT", + "GOOGLE_API_USE_MTLS_ENDPOINT", +) +@mock.patch("google.auth.transport.mtls.getenv", autospec=True) +def test_should_use_mtls_endpoint_invalid_value(mock_getenv, caplog): + mock_getenv.side_effect = ( + lambda var, default=None: "invalid_value" + if var == "GOOGLE_API_USE_MTLS_ENDPOINT" + else default + ) + with caplog.at_level("WARNING"): + assert mtls.should_use_mtls_endpoint(True) is True + assert "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value" in caplog.text + assert "Defaulting to auto" in caplog.text + + caplog.clear() + + with caplog.at_level("WARNING"): + assert mtls.should_use_mtls_endpoint(False) is False + assert "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value" in caplog.text + assert "Defaulting to auto" in caplog.text + + +@mock.patch( + "google.auth.environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT", + "GOOGLE_API_USE_MTLS_ENDPOINT", +) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +@mock.patch("google.auth.transport.mtls.getenv", autospec=True) +def test_should_use_mtls_endpoint_default_client_cert( + mock_getenv, mock_should_use_client_cert +): + mock_getenv.side_effect = ( + lambda var, default=None: "auto" + if var == "GOOGLE_API_USE_MTLS_ENDPOINT" + else default + ) + mock_should_use_client_cert.return_value = True + assert mtls.should_use_mtls_endpoint() is True + mock_should_use_client_cert.assert_called_once() + + mock_should_use_client_cert.reset_mock() + + mock_should_use_client_cert.return_value = False + assert mtls.should_use_mtls_endpoint() is False + mock_should_use_client_cert.assert_called_once()