Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/google-auth/google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@
"GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"
)
"""Environment variable to prevent agent token sharing for GCP services."""

GOOGLE_API_USE_MTLS_ENDPOINT = "GOOGLE_API_USE_MTLS_ENDPOINT"
"""Environment variable controlling whether to use mTLS endpoint or not."""
178 changes: 175 additions & 3 deletions packages/google-auth/google/auth/transport/mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@

"""Utilites for mutual TLS."""

import logging
from os import getenv
import ssl
from typing import Optional

from google.auth import environment_vars
from google.auth import exceptions
from google.auth.transport import _mtls_helper


_LOGGER = logging.getLogger(__name__)


def has_default_client_cert_source(include_context_aware=True):
"""Check if default client SSL credentials exists on the device.

Expand Down Expand Up @@ -60,7 +67,7 @@ def default_client_cert_source():
client certificate bytes and private key bytes, both in PEM format.

Raises:
google.auth.exceptions.DefaultClientCertSourceError: If the default
google.auth.exceptions.MutualTLSChannelError: If the default
client SSL credentials don't exist or are malformed.
"""
if not has_default_client_cert_source(include_context_aware=True):
Expand All @@ -71,7 +78,12 @@ def default_client_cert_source():
def callback():
try:
_, cert_bytes, key_bytes = _mtls_helper.get_client_cert_and_key()
except (OSError, RuntimeError, ValueError) as caught_exc:
except (
exceptions.ClientCertError,
OSError,
RuntimeError,
ValueError,
) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

Expand All @@ -96,7 +108,7 @@ def default_client_encrypted_cert_source(cert_path, key_path):
returns the cert_path, key_path and passphrase bytes.

Raises:
google.auth.exceptions.DefaultClientCertSourceError: If any problem
google.auth.exceptions.MutualTLSChannelError: If any problem
occurs when loading or saving the client certificate and key.
"""
if not has_default_client_cert_source(include_context_aware=True):
Expand Down Expand Up @@ -140,3 +152,163 @@ def should_use_client_cert():
bool: indicating whether the client certificate should be used for mTLS.
"""
return _mtls_helper.check_use_client_cert()


def load_client_cert_into_context(
ctx: ssl.SSLContext,
cert_bytes: bytes,
key_bytes: bytes,
passphrase: Optional[bytes] = None,
) -> None:
"""Load a client certificate and key into an SSL context.

Args:
ctx (ssl.SSLContext): The SSL context to load the certificate and key into.
cert_bytes (bytes): The client certificate bytes in PEM format.
key_bytes (bytes): The client private key bytes in PEM format.
passphrase (Optional[bytes]): The passphrase for the client private key.

Raises:
google.auth.exceptions.MutualTLSChannelError: If the SSL context is invalid,
or if loading the certificate and key fails.
"""
if ctx is None or not hasattr(ctx, "load_cert_chain"):
raise exceptions.MutualTLSChannelError(
"Failed to load client certificate and key for mTLS. The provided context "
"object is invalid or does not support loading certificate chains."
)

try:
with _mtls_helper.secure_cert_key_paths(
cert_bytes, key_bytes, passphrase=passphrase
) as (
cert_path,
key_path,
passphrase_val,
):
ctx.load_cert_chain(
certfile=cert_path, keyfile=key_path, password=passphrase_val
)
except (
ssl.SSLError,
OSError,
ValueError,
RuntimeError,
) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(
"Failed to load client certificate and key for mTLS."
)
raise new_exc from caught_exc


def make_client_cert_ssl_context(
cert_bytes: bytes,
key_bytes: bytes,
passphrase: Optional[bytes] = None,
) -> ssl.SSLContext:
"""Create a default SSL context loaded with the client certificate and key.

Args:
cert_bytes (bytes): The client certificate bytes in PEM format.
key_bytes (bytes): The client private key bytes in PEM format.
passphrase (Optional[bytes]): The passphrase for the client private key.

Returns:
ssl.SSLContext: The SSL context loaded with the client certificate and key.

Raises:
google.auth.exceptions.MutualTLSChannelError: If loading the certificate and key fails.
"""
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
load_client_cert_into_context(ctx, cert_bytes, key_bytes, passphrase=passphrase)
return ctx


def load_default_client_cert(ctx: ssl.SSLContext) -> bool:
"""Load the default client certificate and key into an SSL context if configured.

If client certificates are enabled and a default client certificate source is
found, the certificate and key are loaded into the SSL context.

Args:
ctx (ssl.SSLContext): The SSL context to load the default client certificate
and key into.

Returns:
bool: True if client certificates are enabled and the default client
certificate was successfully loaded. False if client certificates
are disabled or if no default certificate source is configured.

Raises:
google.auth.exceptions.ClientCertError: If the default client certificate
source exists but cannot be loaded or parsed.
google.auth.exceptions.MutualTLSChannelError: If the default client certificate
or key is malformed.
"""
if not should_use_client_cert() or not has_default_client_cert_source():
return False
(
has_cert,
cert_bytes,
key_bytes,
passphrase,
) = _mtls_helper.get_client_ssl_credentials()
if not has_cert:
return False
load_client_cert_into_context(ctx, cert_bytes, key_bytes, passphrase)
return True


def get_default_ssl_context() -> Optional[ssl.SSLContext]:
"""Get a default SSL context loaded with the default client certificate.

Returns:
ssl.SSLContext: An SSL context loaded with the default client
certificate, or None if client certificates are not configured
or available.

Raises:
google.auth.exceptions.ClientCertError: If the default client certificate
source exists but cannot be loaded or parsed.
google.auth.exceptions.MutualTLSChannelError: If the default client certificate
or key is malformed.
"""
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
return ctx if load_default_client_cert(ctx) else None


def should_use_mtls_endpoint(
client_cert_available: Optional[bool] = None,
) -> bool:
"""Determine whether to use an mTLS endpoint.

This relies on the GOOGLE_API_USE_MTLS_ENDPOINT environment variable. If set to
"always", returns True. If set to "never", returns False. If set to "auto"
or unset, returns whether a client certificate is available.

Args:
client_cert_available (bool): indicating if a client certificate
is available. If None, this is determined by checking if client
certificates are enabled and a default source is present.

Returns:
bool: indicating if an mTLS endpoint should be used.
"""
if client_cert_available is None:
client_cert_available = should_use_client_cert()

use_mtls_endpoint = getenv(environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT, "auto")
use_mtls_endpoint = use_mtls_endpoint.lower()
if use_mtls_endpoint == "always":
return True
if use_mtls_endpoint == "never":
return False
if use_mtls_endpoint == "auto":
return client_cert_available

_LOGGER.warning(
"Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value %r. Accepted "
"values: never, auto, always. Defaulting to auto.",
use_mtls_endpoint,
)
return client_cert_available
Loading
Loading