diff --git a/acouchbase_analytics/cluster.py b/acouchbase_analytics/cluster.py index bf07129..0dafd3e 100644 --- a/acouchbase_analytics/cluster.py +++ b/acouchbase_analytics/cluster.py @@ -158,6 +158,22 @@ def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaita """ # noqa: E501 return self._impl.start_query(statement, *args, **kwargs) + async def set_credential(self, credential: Credential) -> None: + """Replace the credential used for subsequent HTTP requests. + + Allows updating credentials (in particular, rotating a JWT) without restarting + the application. The new credential must be of the same type as the current + credential. + + Args: + credential: The new :class:`.Credential` to use. + + Raises: + TypeError: If the new credential is a different type than the current + credential. + """ + await self._impl.set_credential(credential) + async def shutdown(self) -> None: """Shuts down this cluster instance. Cleaning up all resources associated with it. diff --git a/acouchbase_analytics/cluster.pyi b/acouchbase_analytics/cluster.pyi index 1745c5f..70c3fa1 100644 --- a/acouchbase_analytics/cluster.pyi +++ b/acouchbase_analytics/cluster.pyi @@ -91,6 +91,7 @@ class AsyncCluster: ) -> Awaitable[AsyncQueryHandle]: ... @overload def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle]: ... + def set_credential(self, credential: Credential) -> Awaitable[None]: ... def shutdown(self) -> Awaitable[None]: ... @overload @classmethod diff --git a/acouchbase_analytics/protocol/_core/client_adapter.py b/acouchbase_analytics/protocol/_core/client_adapter.py index c6d00c5..6751a76 100644 --- a/acouchbase_analytics/protocol/_core/client_adapter.py +++ b/acouchbase_analytics/protocol/_core/client_adapter.py @@ -20,11 +20,12 @@ from typing import Optional, cast from uuid import uuid4 -from httpx import URL, AsyncClient, BasicAuth, Response +from httpx import URL, AsyncClient, Response -from couchbase_analytics.common.credential import Credential +from couchbase_analytics.common.credential import Credential, _CredentialHolder from couchbase_analytics.common.deserializer import Deserializer from couchbase_analytics.common.logging import LogLevel, log_message +from couchbase_analytics.protocol._core.auth import DynamicCredentialAuth from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import _ConnectionDetails from couchbase_analytics.protocol.options import OptionsBuilder @@ -47,6 +48,7 @@ def __init__( self._opts_builder = OptionsBuilder() kwargs['logger_name'] = self.logger_name self._conn_details = _ConnectionDetails.create(self._opts_builder, http_endpoint, credential, options, **kwargs) + self._credential_holder = _CredentialHolder(credential) # PYCO-67: Do we want to allow supporting custom HTTP transports? self._http_transport_cls = None @@ -78,6 +80,13 @@ def connection_details(self) -> _ConnectionDetails: """ return self._conn_details + @property + def credential_holder(self) -> _CredentialHolder: + """ + **INTERNAL** + """ + return self._credential_holder + @property def default_deserializer(self) -> Deserializer: """ @@ -136,6 +145,7 @@ async def create_client(self) -> None: **INTERNAL** """ if not hasattr(self, '_client'): + auth = DynamicCredentialAuth(self._credential_holder) if self._conn_details.is_secure(): if self._conn_details.ssl_context is None: raise ValueError('SSL context is required for secure connections.') @@ -144,14 +154,14 @@ async def create_client(self) -> None: transport = self._http_transport_cls(verify=self._conn_details.ssl_context) self._client = AsyncClient( verify=self._conn_details.ssl_context, - auth=BasicAuth(*self._conn_details.credential), + auth=auth, transport=transport, ) else: transport = None if self._http_transport_cls is not None: transport = self._http_transport_cls() - self._client = AsyncClient(auth=BasicAuth(*self._conn_details.credential), transport=transport) + self._client = AsyncClient(auth=auth, transport=transport) self.log_message( (f'Cluster HTTP client created: connection_details={self._conn_details.get_init_details()}'), LogLevel.INFO, @@ -195,5 +205,10 @@ def reset_client(self) -> None: if hasattr(self, '_client'): del self._client + async def update_credential(self, new_credential: Credential) -> None: + self._credential_holder.replace(new_credential) + # Future mTLS: await close_client(), rebuild SSL context, await create_client(). + self.log_message('Cluster HTTP credential updated', LogLevel.INFO) + logger = logging.getLogger(_AsyncClientAdapter.LOGGER_NAME) diff --git a/acouchbase_analytics/protocol/cluster.py b/acouchbase_analytics/protocol/cluster.py index c28661b..31c5e17 100644 --- a/acouchbase_analytics/protocol/cluster.py +++ b/acouchbase_analytics/protocol/cluster.py @@ -99,6 +99,9 @@ async def shutdown(self) -> None: else: self.client_adapter.log_message('Cluster does not have a connection. Ignoring shutdown.', LogLevel.WARNING) + async def set_credential(self, credential: Credential) -> None: + await self._client_adapter.update_credential(credential) + async def _execute_query(self, http_resp: AsyncHttpStreamingResponse) -> AsyncQueryResult: if not self.has_client: self.client_adapter.log_message( diff --git a/acouchbase_analytics/protocol/cluster.pyi b/acouchbase_analytics/protocol/cluster.pyi index 3fd9f3a..0a8e67d 100644 --- a/acouchbase_analytics/protocol/cluster.pyi +++ b/acouchbase_analytics/protocol/cluster.pyi @@ -51,7 +51,8 @@ class AsyncCluster: def client_adapter(self) -> _AsyncClientAdapter: ... @property def connected(self) -> bool: ... - def shutdown(self) -> Awaitable[None]: ... + async def set_credential(self, credential: Credential) -> None: ... + async def shutdown(self) -> None: ... def database(self, name: str) -> AsyncDatabase: ... @overload def execute_query(self, statement: str) -> Awaitable[AsyncQueryResult]: ... diff --git a/acouchbase_analytics/tests/credential_t.py b/acouchbase_analytics/tests/credential_t.py new file mode 100644 index 0000000..623e431 --- /dev/null +++ b/acouchbase_analytics/tests/credential_t.py @@ -0,0 +1,227 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from base64 import b64encode +from typing import Any + +import pytest +from httpx import Request + +from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter +from couchbase_analytics.credential import Credential +from couchbase_analytics.protocol._core.auth import DynamicCredentialAuth + +_SAMPLE_JWT = 'eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature' + + +def _authorization_header(client: Any) -> str: + auth = DynamicCredentialAuth(client.credential_holder) + request_url = f'{client.connection_details.url.get_formatted_url()}{client.analytics_path}' + req = Request('POST', request_url) + flow = auth.auth_flow(req) + dispatched = next(flow) + return dispatched.headers['Authorization'] + + +class CredentialTestSuite: + TEST_MANIFEST = [ + 'test_jwt_credential_creation', + 'test_jwt_credential_strips_token', + 'test_jwt_credential_rejects_non_string', + 'test_credential_direct_construction_with_jwt', + 'test_credential_direct_construction_with_password', + 'test_credential_rejects_unknown_kwargs', + 'test_credential_hides_internal_details', + 'test_credential_from_callable_with_jwt', + 'test_jwt_credential_repr_redacts_token', + 'test_jwt_credential_rejects_http_endpoint', + 'test_jwt_credential_accepts_https_endpoint', + 'test_password_credential_http_authorization_header', + 'test_password_credential_repr_redacts_password', + 'test_dynamic_auth_sets_header_from_current_credential', + 'test_async_dynamic_auth_sets_header_from_current_credential', + 'test_dynamic_auth_picks_up_rotated_credential', + 'test_set_credential_same_type_updates_state', + 'test_set_credential_password_to_jwt_fails', + 'test_set_credential_jwt_to_password_fails', + 'test_set_credential_failure_does_not_change_state', + ] + + def test_jwt_credential_creation(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _AsyncClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_jwt_credential_strips_token(self) -> None: + cred = Credential.from_jwt(f' {_SAMPLE_JWT}\n') + client = _AsyncClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + @pytest.mark.parametrize( + ('bad_token', 'expected_exc'), + [ + (12345, TypeError), + (None, TypeError), + (b'bytes.jwt.token', TypeError), + (1.5, TypeError), + (['a', 'b'], TypeError), + ('', ValueError), + (' \n', ValueError), + ], + ) + def test_jwt_credential_rejects_non_string(self, bad_token: object, expected_exc: type) -> None: + with pytest.raises(expected_exc): + Credential.from_jwt(bad_token) # type: ignore[arg-type] + + def test_credential_direct_construction_with_jwt(self) -> None: + cred = Credential(jwt_token=f' {_SAMPLE_JWT}\n') + client = _AsyncClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_credential_direct_construction_with_password(self) -> None: + cred = Credential(username='Administrator', password='password') + client = _AsyncClientAdapter('http://localhost', cred) + expected = 'Basic ' + b64encode(b'Administrator:password').decode('ascii') + assert _authorization_header(client) == expected + + def test_credential_rejects_unknown_kwargs(self) -> None: + with pytest.raises(TypeError, match='unexpected keyword argument'): + Credential(usernme='Administrator', password='password') # type: ignore[call-arg] + with pytest.raises(TypeError, match='unexpected keyword argument'): + Credential(jwt_token=_SAMPLE_JWT, extra='ignored') # type: ignore[call-arg] + + def test_credential_hides_internal_details(self) -> None: + public_attrs = {name for name in dir(Credential.from_jwt(_SAMPLE_JWT)) if not name.startswith('_')} + assert public_attrs == {'from_callable', 'from_jwt', 'from_username_and_password'} + + def test_credential_from_callable_with_jwt(self) -> None: + cred = Credential.from_callable(lambda: Credential.from_jwt(_SAMPLE_JWT)) + client = _AsyncClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_jwt_credential_repr_redacts_token(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + rendered = repr(cred) + assert _SAMPLE_JWT not in rendered + assert '****' in rendered + + def test_jwt_credential_rejects_http_endpoint(self) -> None: + with pytest.raises(ValueError, match='require a secure'): + _AsyncClientAdapter('http://localhost', Credential.from_jwt(_SAMPLE_JWT)) + + def test_jwt_credential_accepts_https_endpoint(self) -> None: + client = _AsyncClientAdapter('https://localhost', Credential.from_jwt(_SAMPLE_JWT)) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_password_credential_http_authorization_header(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'password') + client = _AsyncClientAdapter('http://localhost', cred) + expected = 'Basic ' + b64encode(b'Administrator:password').decode('ascii') + assert _authorization_header(client) == expected + + def test_password_credential_repr_redacts_password(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'super-secret') + rendered = repr(cred) + assert 'super-secret' not in rendered + assert '****' in rendered + assert 'Administrator' in rendered + + def test_dynamic_auth_sets_header_from_current_credential(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _AsyncClientAdapter('https://localhost', cred) + auth = DynamicCredentialAuth(client.credential_holder) + + req = Request('POST', 'https://localhost/api/v1/request') + flow = auth.auth_flow(req) + dispatched = next(flow) + assert dispatched.headers['Authorization'] == f'Bearer {_SAMPLE_JWT}' + + @pytest.mark.anyio + async def test_async_dynamic_auth_sets_header_from_current_credential(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _AsyncClientAdapter('https://localhost', cred) + auth = DynamicCredentialAuth(client.credential_holder) + + request_url = f'{client.connection_details.url.get_formatted_url()}{client.analytics_path}' + req = Request('POST', request_url) + flow = auth.async_auth_flow(req) + dispatched = await flow.__anext__() + assert dispatched.headers['Authorization'] == f'Bearer {_SAMPLE_JWT}' + + @pytest.mark.anyio + async def test_dynamic_auth_picks_up_rotated_credential(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _AsyncClientAdapter('https://localhost', cred) + auth = DynamicCredentialAuth(client.credential_holder) + + new_token = 'rotated.jwt.token' + await client.update_credential(Credential.from_jwt(new_token)) + + req = Request('POST', 'https://localhost/api/v1/request') + flow = auth.auth_flow(req) + dispatched = next(flow) + assert dispatched.headers['Authorization'] == f'Bearer {new_token}' + + @pytest.mark.anyio + async def test_set_credential_same_type_updates_state(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _AsyncClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + new_token = 'fresh.jwt.token' + await client.update_credential(Credential.from_jwt(new_token)) + + assert _authorization_header(client) == f'Bearer {new_token}' + + @pytest.mark.anyio + async def test_set_credential_password_to_jwt_fails(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'password') + client = _AsyncClientAdapter('http://localhost', cred) + with pytest.raises(TypeError, match='Cannot switch credential type'): + await client.update_credential(Credential.from_jwt(_SAMPLE_JWT)) + + @pytest.mark.anyio + async def test_set_credential_jwt_to_password_fails(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _AsyncClientAdapter('https://localhost', cred) + with pytest.raises(TypeError, match='Cannot switch credential type'): + await client.update_credential(Credential.from_username_and_password('Administrator', 'password')) + + @pytest.mark.anyio + async def test_set_credential_failure_does_not_change_state(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'password') + client = _AsyncClientAdapter('http://localhost', cred) + original_header = _authorization_header(client) + + with pytest.raises(TypeError): + await client.update_credential(Credential.from_jwt(_SAMPLE_JWT)) + + assert _authorization_header(client) == original_header + + +class CredentialTests(CredentialTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(CredentialTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(CredentialTests) if valid_test_method(meth)] + test_list = set(CredentialTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.') diff --git a/conftest.py b/conftest.py index 42a235f..e912d94 100644 --- a/conftest.py +++ b/conftest.py @@ -26,6 +26,7 @@ _UNIT_TESTS = [ 'acouchbase_analytics/tests/connection_t.py::ConnectionTests', + 'acouchbase_analytics/tests/credential_t.py::CredentialTests', 'acouchbase_analytics/tests/json_parsing_t.py::JsonParsingTests', 'acouchbase_analytics/tests/options_t.py::ClusterOptionsTests', 'acouchbase_analytics/tests/query_options_t.py::ClusterQueryOptionsTests', @@ -33,6 +34,7 @@ 'acouchbase_analytics/tests/test_server_t.py::ClusterTestServerTests', 'acouchbase_analytics/tests/test_server_t.py::ScopeTestServerTests', 'couchbase_analytics/tests/connection_t.py::ConnectionTests', + 'couchbase_analytics/tests/credential_t.py::CredentialTests', 'couchbase_analytics/tests/duration_parsing_t.py::DurationParsingTests', 'couchbase_analytics/tests/json_parsing_t.py::JsonParsingTests', 'couchbase_analytics/tests/options_t.py::ClusterOptionsTests', diff --git a/couchbase_analytics/cluster.py b/couchbase_analytics/cluster.py index 9c133de..3736f21 100644 --- a/couchbase_analytics/cluster.py +++ b/couchbase_analytics/cluster.py @@ -156,6 +156,22 @@ def start_query(self, statement: str, *args: object, **kwargs: object) -> Blocki """ # noqa: E501 return self._impl.start_query(statement, *args, **kwargs) + def set_credential(self, credential: Credential) -> None: + """Replace the credential used for subsequent HTTP requests. + + Allows updating credentials (in particular, rotating a JWT) without restarting + the application. The new credential must be of the same type as the current + credential. + + Args: + credential: The new :class:`.Credential` to use. + + Raises: + TypeError: If the new credential is a different type than the current + credential. + """ + self._impl.set_credential(credential) + def shutdown(self) -> None: """Shuts down this cluster instance. Cleaning up all resources associated with it. diff --git a/couchbase_analytics/cluster.pyi b/couchbase_analytics/cluster.pyi index 0b32759..2bf2472 100644 --- a/couchbase_analytics/cluster.pyi +++ b/couchbase_analytics/cluster.pyi @@ -142,6 +142,7 @@ class Cluster: ) -> BlockingQueryHandle: ... @overload def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... + def set_credential(self, credential: Credential) -> None: ... def shutdown(self) -> None: ... @overload @classmethod diff --git a/couchbase_analytics/common/credential.py b/couchbase_analytics/common/credential.py index 72a5b4b..52d246a 100644 --- a/couchbase_analytics/common/credential.py +++ b/couchbase_analytics/common/credential.py @@ -13,50 +13,79 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Credentials for the Analytics SDK.""" from __future__ import annotations -from typing import Callable, Dict, Tuple +from base64 import b64encode +from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional +if TYPE_CHECKING: + from httpx import Request -class Credential: - """Create a Credential instance. - A Credential is required in order to connect to a Analytics endpoint. +class Credential: + """A credential for authenticating with a Couchbase Analytics endpoint. - .. important:: - Use the the provided classmethods to create a :class:`.Credential` instance. + Construct via a factory classmethod or direct keyword arguments:: + Credential(username='Administrator', password='swordfish') + Credential(jwt_token='eyJ...') + Credential.from_username_and_password('Administrator', 'swordfish') + Credential.from_jwt('eyJ...') + Credential.from_callable(lambda: Credential.from_jwt(get_token())) """ - def __init__(self, **kwargs: str) -> None: - username = kwargs.pop('username', None) - password = kwargs.pop('password', None) - - if username is None: - raise ValueError('Must provide a username.') - if not isinstance(username, str): - raise ValueError('The username must be a str.') - - if password is None: - raise ValueError('Must provide a password.') - if not isinstance(password, str): - raise ValueError('The password must be a str.') - - self._username = username - self._password = password - - def asdict(self) -> Dict[str, str]: - """ - **INTERNAL** - """ - return {'username': self._username, 'password': self._password} - - def astuple(self) -> Tuple[bytes, bytes]: - """ - **INTERNAL** - """ - return self._username.encode(), self._password.encode() + __slots__ = ('_kind', '_header', '_password', '_token', '_username') + + _kind: Literal['password', 'jwt'] + _header: str + _password: Optional[str] + _token: Optional[str] + _username: Optional[str] + + def __init__( + self, + *, + username: Optional[str] = None, + password: Optional[str] = None, + jwt_token: Optional[str] = None, + ) -> None: + if jwt_token is not None: + if username is not None or password is not None: + raise ValueError('Cannot provide both a JWT token and username/password.') + if not isinstance(jwt_token, str): + raise TypeError('The JWT token must be a str.') + jwt_token = jwt_token.strip() + if not jwt_token: + raise ValueError('The JWT token must not be empty.') + self._kind = 'jwt' + self._header = f'Bearer {jwt_token}' + self._password = None + self._token = jwt_token + self._username = None + elif username is not None or password is not None: + if username is None: + raise ValueError('Must provide a username.') + if password is None: + raise ValueError('Must provide a password.') + if not isinstance(username, str): + raise TypeError('The username must be a str.') + if not isinstance(password, str): + raise TypeError('The password must be a str.') + self._kind = 'password' + self._header = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode('ascii') + self._password = password + self._token = None + self._username = username + else: + raise ValueError('Must provide either jwt_token or username and password.') + + def _asdict(self) -> Dict[str, str]: + """**INTERNAL**""" + if self._kind == 'jwt': + return {'jwt_token': self._token or ''} + return {'username': self._username or '', 'password': self._password or ''} @classmethod def from_username_and_password(cls, username: str, password: str) -> Credential: @@ -65,39 +94,95 @@ def from_username_and_password(cls, username: str, password: str) -> Credential: Args: username: The username for the Analytics endpoint. password: The password for the Analytics endpoint. + """ + if not isinstance(username, str): + raise TypeError('The username must be a str.') + if not isinstance(password, str): + raise TypeError('The password must be a str.') + return cls(username=username, password=password) - Returns: - A Credential instance. + @classmethod + def from_jwt(cls, token: str) -> Credential: + """Create a :class:`.Credential` from a JSON Web Token (JWT). + + The SDK sends an ``Authorization: Bearer `` header on every HTTP request. + JWT credentials may only be used with an ``https://`` endpoint. + + .. note:: + JWT credentials have a limited validity period. To avoid authentication + failures, periodically pass a fresh credential via + :meth:`~couchbase_analytics.cluster.Cluster.set_credential`. + + Args: + token: The JSON Web Token. """ - return Credential(username=username, password=password) + if not isinstance(token, str): + raise TypeError('The JWT token must be a str.') + return cls(jwt_token=token) @classmethod def from_callable(cls, callback: Callable[[], Credential]) -> Credential: - """Create a :class:`.Credential` from provided callback. + """Create a :class:`.Credential` by invoking a callback. - The callback is + The callback is invoked once at construction time; it is not retained for + dynamic credential lookup. To pick up a refreshed credential later, pass + a fresh :class:`.Credential` to + :meth:`~couchbase_analytics.cluster.Cluster.set_credential`. Args: callback: Callback that returns a :class:`.Credential`. - Returns: - A Credential instance. + Example:: + + cred = Credential.from_callable( + lambda: Credential.from_username_and_password( + os.environ['USERNAME'], os.environ['PASSWORD'] + ) + ) + """ + return cls(**callback()._asdict()) - Example: - Retrieve credentials from environment variables:: + # Internal contract for the transport layer. Credential rotation is + # last-writer-wins; each request sees a fully-constructed credential. - def _cred_from_env() -> Credential: - from os import getenv - return Credential.from_username_and_password(getenv('PYCBCC_USERNAME'), - getenv('PYCBCC_PW')) + def _apply_to_request(self, request: Request) -> None: + request.headers['Authorization'] = self._header - cred = Credential.from_callable(_cred_from_env) + def _check_endpoint_compatible(self, is_secure: bool) -> None: + if self._kind == 'jwt' and not is_secure: + raise ValueError('JWT credentials require a secure (https) connection.') - """ - return Credential(**callback().asdict()) + def _check_replaceable_with(self, new_credential: Credential) -> None: + if self._kind != new_credential._kind: + raise TypeError( + f'Cannot switch credential type at runtime; current type is ' + f'{self._kind}, new type is {new_credential._kind}.' + ) def __repr__(self) -> str: - return f'Credential(username={self._username}, password=****)' + if self._kind == 'password': + return f'Credential(username={self._username!r}, password=****)' + return 'Credential(jwt_token=****)' + + +class _CredentialHolder: + """**INTERNAL** Owns the mutable current credential; transport calls + apply_to_request per request and replace for rotation. + """ + + __slots__ = ('_credential',) + + def __init__(self, credential: Credential) -> None: + self._credential = credential + + @property + def credential(self) -> Credential: + return self._credential + + def apply_to_request(self, request: Request) -> None: + self._credential._apply_to_request(request) - def __str__(self) -> str: - return self.__repr__() + def replace(self, new_credential: Credential) -> None: + # GIL-atomic store; concurrent rotators are last-writer-wins. + self._credential._check_replaceable_with(new_credential) + self._credential = new_credential diff --git a/couchbase_analytics/protocol/_core/auth.py b/couchbase_analytics/protocol/_core/auth.py new file mode 100644 index 0000000..b5b5d2a --- /dev/null +++ b/couchbase_analytics/protocol/_core/auth.py @@ -0,0 +1,43 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import TYPE_CHECKING, AsyncGenerator, Generator + +from httpx import Auth, Request, Response + +if TYPE_CHECKING: + from couchbase_analytics.common.credential import _CredentialHolder + + +class DynamicCredentialAuth(Auth): + """httpx ``Auth`` that reads the current credential from a ``_CredentialHolder`` at + request time, so rotating a credential via ``Cluster.set_credential`` takes effect + immediately without rebuilding the HTTP client. + """ + + def __init__(self, credential_holder: _CredentialHolder) -> None: + self._credential_holder = credential_holder + + def auth_flow(self, request: Request) -> Generator[Request, Response, None]: + self._credential_holder.apply_to_request(request) + yield request + + async def async_auth_flow(self, request: Request) -> AsyncGenerator[Request, Response]: + # Mirror of auth_flow — applying a credential is non-blocking. + self._credential_holder.apply_to_request(request) + yield request diff --git a/couchbase_analytics/protocol/_core/client_adapter.py b/couchbase_analytics/protocol/_core/client_adapter.py index 04fc078..143a439 100644 --- a/couchbase_analytics/protocol/_core/client_adapter.py +++ b/couchbase_analytics/protocol/_core/client_adapter.py @@ -20,11 +20,12 @@ from typing import Optional, cast from uuid import uuid4 -from httpx import URL, BasicAuth, Client, Response +from httpx import URL, Client, Response -from couchbase_analytics.common.credential import Credential +from couchbase_analytics.common.credential import Credential, _CredentialHolder from couchbase_analytics.common.deserializer import Deserializer from couchbase_analytics.common.logging import LogLevel, log_message +from couchbase_analytics.protocol._core.auth import DynamicCredentialAuth from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import _ConnectionDetails from couchbase_analytics.protocol.options import OptionsBuilder @@ -49,6 +50,7 @@ def __init__( self._http_transport_cls = None kwargs['logger_name'] = self.logger_name self._conn_details = _ConnectionDetails.create(self._opts_builder, http_endpoint, credential, options, **kwargs) + self._credential_holder = _CredentialHolder(credential) @property def analytics_path(self) -> str: @@ -78,6 +80,13 @@ def connection_details(self) -> _ConnectionDetails: """ return self._conn_details + @property + def credential_holder(self) -> _CredentialHolder: + """ + **INTERNAL** + """ + return self._credential_holder + @property def default_deserializer(self) -> Deserializer: """ @@ -136,7 +145,7 @@ def create_client(self) -> None: **INTERNAL** """ if not hasattr(self, '_client'): - auth = BasicAuth(*self._conn_details.credential) + auth = DynamicCredentialAuth(self._credential_holder) if self._conn_details.is_secure(): if self._conn_details.ssl_context is None: raise ValueError('SSL context is required for secure connections.') @@ -187,5 +196,10 @@ def reset_client(self) -> None: if hasattr(self, '_client'): del self._client + def update_credential(self, new_credential: Credential) -> None: + self._credential_holder.replace(new_credential) + # Future mTLS: rebuild SSL context + httpx Client here. + self.log_message('Cluster HTTP credential updated', LogLevel.INFO) + logger = logging.getLogger(_ClientAdapter.LOGGER_NAME) diff --git a/couchbase_analytics/protocol/cluster.py b/couchbase_analytics/protocol/cluster.py index 4e77b62..a285b08 100644 --- a/couchbase_analytics/protocol/cluster.py +++ b/couchbase_analytics/protocol/cluster.py @@ -122,6 +122,9 @@ def shutdown(self) -> None: else: self._client_adapter.log_message('Cluster does not have a connection, no need to shutdown.', LogLevel.INFO) + def set_credential(self, credential: Credential) -> None: + self._client_adapter.update_credential(credential) + def execute_query( self, statement: str, *args: object, **kwargs: object ) -> Union[BlockingQueryResult, Future[BlockingQueryResult]]: diff --git a/couchbase_analytics/protocol/cluster.pyi b/couchbase_analytics/protocol/cluster.pyi index 206bb55..169dd5b 100644 --- a/couchbase_analytics/protocol/cluster.pyi +++ b/couchbase_analytics/protocol/cluster.pyi @@ -147,6 +147,7 @@ class Cluster: ) -> BlockingQueryHandle: ... @overload def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... + def set_credential(self, credential: Credential) -> None: ... def shutdown(self) -> None: ... @overload @classmethod diff --git a/couchbase_analytics/protocol/connection.py b/couchbase_analytics/protocol/connection.py index 031f511..dc4255d 100644 --- a/couchbase_analytics/protocol/connection.py +++ b/couchbase_analytics/protocol/connection.py @@ -155,13 +155,13 @@ def parse_query_str_options( @dataclass class _ConnectionDetails: - """ - **INTERNAL** + """**INTERNAL** Cluster-level config: URL, options, SSL context. + + Credential state lives in :class:`._CredentialHolder`, not here. """ url: RequestURL cluster_options: ClusterOptionsTransformedKwargs - credential: Tuple[bytes, bytes] default_deserializer: Deserializer ssl_context: Optional[ssl.SSLContext] = None sni_hostname: Optional[str] = None @@ -202,7 +202,8 @@ def get_query_timeout(self) -> float: def is_secure(self) -> bool: return self.url.scheme == 'https' - def validate_security_options(self) -> None: # noqa: C901 + def validate_security_options(self, credential: Credential) -> None: # noqa: C901 + credential._check_endpoint_compatible(self.is_secure()) security_opts: Optional[SecurityOptionsTransformedKwargs] = self.cluster_options.get('security_options') if security_opts is not None: # separate between value options and boolean option (trust_only_capella) @@ -275,6 +276,6 @@ def create( if default_deserializer is None: default_deserializer = DefaultJsonDeserializer() - conn_dtls = cls(url, cluster_opts, credential.astuple(), default_deserializer, logger_name=logger_name) - conn_dtls.validate_security_options() + conn_dtls = cls(url, cluster_opts, default_deserializer, logger_name=logger_name) + conn_dtls.validate_security_options(credential) return conn_dtls diff --git a/couchbase_analytics/tests/credential_t.py b/couchbase_analytics/tests/credential_t.py new file mode 100644 index 0000000..998a265 --- /dev/null +++ b/couchbase_analytics/tests/credential_t.py @@ -0,0 +1,210 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from base64 import b64encode +from typing import Any + +import pytest +from httpx import Request + +from couchbase_analytics.credential import Credential +from couchbase_analytics.protocol._core.auth import DynamicCredentialAuth +from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter + +_SAMPLE_JWT = 'eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature' + + +def _authorization_header(client: Any) -> str: + auth = DynamicCredentialAuth(client.credential_holder) + request_url = f'{client.connection_details.url.get_formatted_url()}{client.analytics_path}' + req = Request('POST', request_url) + flow = auth.auth_flow(req) + dispatched = next(flow) + return dispatched.headers['Authorization'] + + +class CredentialTestSuite: + TEST_MANIFEST = [ + 'test_jwt_credential_creation', + 'test_jwt_credential_strips_token', + 'test_jwt_credential_rejects_non_string', + 'test_credential_direct_construction_with_jwt', + 'test_credential_direct_construction_with_password', + 'test_credential_rejects_unknown_kwargs', + 'test_credential_hides_internal_details', + 'test_credential_from_callable_with_jwt', + 'test_jwt_credential_repr_redacts_token', + 'test_jwt_credential_rejects_http_endpoint', + 'test_jwt_credential_accepts_https_endpoint', + 'test_password_credential_http_authorization_header', + 'test_password_credential_repr_redacts_password', + 'test_dynamic_auth_sets_header_from_current_credential', + 'test_dynamic_auth_picks_up_rotated_credential', + 'test_set_credential_same_type_updates_state', + 'test_set_credential_password_to_jwt_fails', + 'test_set_credential_jwt_to_password_fails', + 'test_set_credential_failure_does_not_change_state', + ] + + def test_jwt_credential_creation(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _ClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_jwt_credential_strips_token(self) -> None: + cred = Credential.from_jwt(f' {_SAMPLE_JWT}\n') + client = _ClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + @pytest.mark.parametrize( + ('bad_token', 'expected_exc'), + [ + (12345, TypeError), + (None, TypeError), + (b'bytes.jwt.token', TypeError), + (1.5, TypeError), + (['a', 'b'], TypeError), + ('', ValueError), + (' \n', ValueError), + ], + ) + def test_jwt_credential_rejects_non_string(self, bad_token: object, expected_exc: type) -> None: + with pytest.raises(expected_exc): + Credential.from_jwt(bad_token) # type: ignore[arg-type] + + def test_credential_direct_construction_with_jwt(self) -> None: + cred = Credential(jwt_token=f' {_SAMPLE_JWT}\n') + client = _ClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_credential_direct_construction_with_password(self) -> None: + cred = Credential(username='Administrator', password='password') + client = _ClientAdapter('http://localhost', cred) + expected = 'Basic ' + b64encode(b'Administrator:password').decode('ascii') + assert _authorization_header(client) == expected + + def test_credential_rejects_unknown_kwargs(self) -> None: + with pytest.raises(TypeError, match='unexpected keyword argument'): + Credential(usernme='Administrator', password='password') # type: ignore[call-arg] + with pytest.raises(TypeError, match='unexpected keyword argument'): + Credential(jwt_token=_SAMPLE_JWT, extra='ignored') # type: ignore[call-arg] + + def test_credential_hides_internal_details(self) -> None: + public_attrs = {name for name in dir(Credential.from_jwt(_SAMPLE_JWT)) if not name.startswith('_')} + assert public_attrs == {'from_callable', 'from_jwt', 'from_username_and_password'} + + def test_credential_from_callable_with_jwt(self) -> None: + cred = Credential.from_callable(lambda: Credential.from_jwt(_SAMPLE_JWT)) + client = _ClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_jwt_credential_repr_redacts_token(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + rendered = repr(cred) + assert _SAMPLE_JWT not in rendered + assert '****' in rendered + + def test_jwt_credential_rejects_http_endpoint(self) -> None: + with pytest.raises(ValueError, match='require a secure'): + _ClientAdapter('http://localhost', Credential.from_jwt(_SAMPLE_JWT)) + + def test_jwt_credential_accepts_https_endpoint(self) -> None: + client = _ClientAdapter('https://localhost', Credential.from_jwt(_SAMPLE_JWT)) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + def test_password_credential_http_authorization_header(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'password') + client = _ClientAdapter('http://localhost', cred) + expected = 'Basic ' + b64encode(b'Administrator:password').decode('ascii') + assert _authorization_header(client) == expected + + def test_password_credential_repr_redacts_password(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'super-secret') + rendered = repr(cred) + assert 'super-secret' not in rendered + assert '****' in rendered + assert 'Administrator' in rendered + + def test_dynamic_auth_sets_header_from_current_credential(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _ClientAdapter('https://localhost', cred) + auth = DynamicCredentialAuth(client.credential_holder) + + req = Request('POST', 'https://localhost/api/v1/request') + flow = auth.auth_flow(req) + dispatched = next(flow) + assert dispatched.headers['Authorization'] == f'Bearer {_SAMPLE_JWT}' + + def test_dynamic_auth_picks_up_rotated_credential(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _ClientAdapter('https://localhost', cred) + auth = DynamicCredentialAuth(client.credential_holder) + + # Build a single auth instance, then rotate the underlying credential. + new_token = 'rotated.jwt.token' + client.update_credential(Credential.from_jwt(new_token)) + + req = Request('POST', 'https://localhost/api/v1/request') + flow = auth.auth_flow(req) + dispatched = next(flow) + assert dispatched.headers['Authorization'] == f'Bearer {new_token}' + + def test_set_credential_same_type_updates_state(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _ClientAdapter('https://localhost', cred) + assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}' + + new_token = 'fresh.jwt.token' + client.update_credential(Credential.from_jwt(new_token)) + + assert _authorization_header(client) == f'Bearer {new_token}' + + def test_set_credential_password_to_jwt_fails(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'password') + client = _ClientAdapter('http://localhost', cred) + with pytest.raises(TypeError, match='Cannot switch credential type'): + client.update_credential(Credential.from_jwt(_SAMPLE_JWT)) + + def test_set_credential_jwt_to_password_fails(self) -> None: + cred = Credential.from_jwt(_SAMPLE_JWT) + client = _ClientAdapter('https://localhost', cred) + with pytest.raises(TypeError, match='Cannot switch credential type'): + client.update_credential(Credential.from_username_and_password('Administrator', 'password')) + + def test_set_credential_failure_does_not_change_state(self) -> None: + cred = Credential.from_username_and_password('Administrator', 'password') + client = _ClientAdapter('http://localhost', cred) + original_header = _authorization_header(client) + + with pytest.raises(TypeError): + client.update_credential(Credential.from_jwt(_SAMPLE_JWT)) + + assert _authorization_header(client) == original_header + + +class CredentialTests(CredentialTestSuite): + @pytest.fixture(scope='class', autouse=True) + def validate_test_manifest(self) -> None: + def valid_test_method(meth: str) -> bool: + attr = getattr(CredentialTests, meth) + return callable(attr) and not meth.startswith('__') and meth.startswith('test') + + method_list = [meth for meth in dir(CredentialTests) if valid_test_method(meth)] + test_list = set(CredentialTestSuite.TEST_MANIFEST).symmetric_difference(method_list) + if test_list: + pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.') diff --git a/docs/acouchbase_analytics_api/credential.rst b/docs/acouchbase_analytics_api/credential.rst index 6f6c617..b96a93d 100644 --- a/docs/acouchbase_analytics_api/credential.rst +++ b/docs/acouchbase_analytics_api/credential.rst @@ -11,5 +11,7 @@ Credential .. automethod:: from_username_and_password :no-index: + .. automethod:: from_jwt + :no-index: .. automethod:: from_callable :no-index: diff --git a/docs/couchbase_analytics_api/credential.rst b/docs/couchbase_analytics_api/credential.rst index 47ced06..80b8007 100644 --- a/docs/couchbase_analytics_api/credential.rst +++ b/docs/couchbase_analytics_api/credential.rst @@ -9,4 +9,5 @@ Credential .. autoclass:: Credential .. automethod:: from_username_and_password + .. automethod:: from_jwt .. automethod:: from_callable diff --git a/tests/utils/_async_client_adapter.py b/tests/utils/_async_client_adapter.py index 70aee69..da2ac40 100644 --- a/tests/utils/_async_client_adapter.py +++ b/tests/utils/_async_client_adapter.py @@ -34,6 +34,7 @@ def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore self._cluster_id = adapter._cluster_id self._opts_builder = adapter._opts_builder self._conn_details = adapter._conn_details + self._credential_holder = adapter._credential_holder if self._http_transport_cls is None: self._http_transport_cls = adapter._http_transport_cls diff --git a/tests/utils/_client_adapter.py b/tests/utils/_client_adapter.py index 7e6cbe8..ae63c35 100644 --- a/tests/utils/_client_adapter.py +++ b/tests/utils/_client_adapter.py @@ -35,6 +35,7 @@ def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore self._cluster_id = adapter._cluster_id self._opts_builder = adapter._opts_builder self._conn_details = adapter._conn_details + self._credential_holder = adapter._credential_holder if self._http_transport_cls is None: self._http_transport_cls = adapter._http_transport_cls