Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions acouchbase_analytics/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,22 @@ def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaita
""" # noqa: E501
return self._impl.start_query(statement, *args, **kwargs)

async def set_credential(self, credential: Credential) -> None:
"""Replace the credential used for subsequent HTTP requests.

Allows updating credentials (in particular, rotating a JWT) without restarting
the application. The new credential must be of the same type as the current
credential.

Args:
credential: The new :class:`.Credential` to use.

Raises:
TypeError: If the new credential is a different type than the current
credential.
"""
await self._impl.set_credential(credential)

async def shutdown(self) -> None:
"""Shuts down this cluster instance. Cleaning up all resources associated with it.

Expand Down
1 change: 1 addition & 0 deletions acouchbase_analytics/cluster.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class AsyncCluster:
) -> Awaitable[AsyncQueryHandle]: ...
@overload
def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryHandle]: ...
def set_credential(self, credential: Credential) -> Awaitable[None]: ...
def shutdown(self) -> Awaitable[None]: ...
@overload
@classmethod
Expand Down
23 changes: 19 additions & 4 deletions acouchbase_analytics/protocol/_core/client_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from typing import Optional, cast
from uuid import uuid4

from httpx import URL, AsyncClient, BasicAuth, Response
from httpx import URL, AsyncClient, Response

from couchbase_analytics.common.credential import Credential
from couchbase_analytics.common.credential import Credential, _CredentialHolder
from couchbase_analytics.common.deserializer import Deserializer
from couchbase_analytics.common.logging import LogLevel, log_message
from couchbase_analytics.protocol._core.auth import DynamicCredentialAuth
from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest
from couchbase_analytics.protocol.connection import _ConnectionDetails
from couchbase_analytics.protocol.options import OptionsBuilder
Expand All @@ -47,6 +48,7 @@ def __init__(
self._opts_builder = OptionsBuilder()
kwargs['logger_name'] = self.logger_name
self._conn_details = _ConnectionDetails.create(self._opts_builder, http_endpoint, credential, options, **kwargs)
self._credential_holder = _CredentialHolder(credential)
# PYCO-67: Do we want to allow supporting custom HTTP transports?
self._http_transport_cls = None

Expand Down Expand Up @@ -78,6 +80,13 @@ def connection_details(self) -> _ConnectionDetails:
"""
return self._conn_details

@property
def credential_holder(self) -> _CredentialHolder:
"""
**INTERNAL**
"""
return self._credential_holder

@property
def default_deserializer(self) -> Deserializer:
"""
Expand Down Expand Up @@ -136,6 +145,7 @@ async def create_client(self) -> None:
**INTERNAL**
"""
if not hasattr(self, '_client'):
auth = DynamicCredentialAuth(self._credential_holder)
if self._conn_details.is_secure():
if self._conn_details.ssl_context is None:
raise ValueError('SSL context is required for secure connections.')
Expand All @@ -144,14 +154,14 @@ async def create_client(self) -> None:
transport = self._http_transport_cls(verify=self._conn_details.ssl_context)
self._client = AsyncClient(
verify=self._conn_details.ssl_context,
auth=BasicAuth(*self._conn_details.credential),
auth=auth,
transport=transport,
)
else:
transport = None
if self._http_transport_cls is not None:
transport = self._http_transport_cls()
self._client = AsyncClient(auth=BasicAuth(*self._conn_details.credential), transport=transport)
self._client = AsyncClient(auth=auth, transport=transport)
self.log_message(
(f'Cluster HTTP client created: connection_details={self._conn_details.get_init_details()}'),
LogLevel.INFO,
Expand Down Expand Up @@ -195,5 +205,10 @@ def reset_client(self) -> None:
if hasattr(self, '_client'):
del self._client

async def update_credential(self, new_credential: Credential) -> None:
self._credential_holder.replace(new_credential)
# Future mTLS: await close_client(), rebuild SSL context, await create_client().
self.log_message('Cluster HTTP credential updated', LogLevel.INFO)


logger = logging.getLogger(_AsyncClientAdapter.LOGGER_NAME)
3 changes: 3 additions & 0 deletions acouchbase_analytics/protocol/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ async def shutdown(self) -> None:
else:
self.client_adapter.log_message('Cluster does not have a connection. Ignoring shutdown.', LogLevel.WARNING)

async def set_credential(self, credential: Credential) -> None:
await self._client_adapter.update_credential(credential)

async def _execute_query(self, http_resp: AsyncHttpStreamingResponse) -> AsyncQueryResult:
if not self.has_client:
self.client_adapter.log_message(
Expand Down
3 changes: 2 additions & 1 deletion acouchbase_analytics/protocol/cluster.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class AsyncCluster:
def client_adapter(self) -> _AsyncClientAdapter: ...
@property
def connected(self) -> bool: ...
def shutdown(self) -> Awaitable[None]: ...
async def set_credential(self, credential: Credential) -> None: ...
async def shutdown(self) -> None: ...
def database(self, name: str) -> AsyncDatabase: ...
@overload
def execute_query(self, statement: str) -> Awaitable[AsyncQueryResult]: ...
Expand Down
227 changes: 227 additions & 0 deletions acouchbase_analytics/tests/credential_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# Copyright 2016-2025. Couchbase, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from base64 import b64encode
from typing import Any

import pytest
from httpx import Request

from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter
from couchbase_analytics.credential import Credential
from couchbase_analytics.protocol._core.auth import DynamicCredentialAuth

_SAMPLE_JWT = 'eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature'


def _authorization_header(client: Any) -> str:
auth = DynamicCredentialAuth(client.credential_holder)
request_url = f'{client.connection_details.url.get_formatted_url()}{client.analytics_path}'
req = Request('POST', request_url)
flow = auth.auth_flow(req)
dispatched = next(flow)
return dispatched.headers['Authorization']
Comment thread
anirudhlakhotia marked this conversation as resolved.


class CredentialTestSuite:
TEST_MANIFEST = [
'test_jwt_credential_creation',
'test_jwt_credential_strips_token',
'test_jwt_credential_rejects_non_string',
'test_credential_direct_construction_with_jwt',
'test_credential_direct_construction_with_password',
'test_credential_rejects_unknown_kwargs',
'test_credential_hides_internal_details',
'test_credential_from_callable_with_jwt',
'test_jwt_credential_repr_redacts_token',
'test_jwt_credential_rejects_http_endpoint',
'test_jwt_credential_accepts_https_endpoint',
'test_password_credential_http_authorization_header',
'test_password_credential_repr_redacts_password',
'test_dynamic_auth_sets_header_from_current_credential',
'test_async_dynamic_auth_sets_header_from_current_credential',
'test_dynamic_auth_picks_up_rotated_credential',
'test_set_credential_same_type_updates_state',
'test_set_credential_password_to_jwt_fails',
'test_set_credential_jwt_to_password_fails',
'test_set_credential_failure_does_not_change_state',
]

def test_jwt_credential_creation(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
client = _AsyncClientAdapter('https://localhost', cred)
assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}'

def test_jwt_credential_strips_token(self) -> None:
cred = Credential.from_jwt(f' {_SAMPLE_JWT}\n')
client = _AsyncClientAdapter('https://localhost', cred)
assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}'

@pytest.mark.parametrize(
('bad_token', 'expected_exc'),
[
(12345, TypeError),
(None, TypeError),
(b'bytes.jwt.token', TypeError),
(1.5, TypeError),
(['a', 'b'], TypeError),
('', ValueError),
(' \n', ValueError),
],
)
def test_jwt_credential_rejects_non_string(self, bad_token: object, expected_exc: type) -> None:
with pytest.raises(expected_exc):
Credential.from_jwt(bad_token) # type: ignore[arg-type]

def test_credential_direct_construction_with_jwt(self) -> None:
cred = Credential(jwt_token=f' {_SAMPLE_JWT}\n')
client = _AsyncClientAdapter('https://localhost', cred)
assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}'

def test_credential_direct_construction_with_password(self) -> None:
cred = Credential(username='Administrator', password='password')
client = _AsyncClientAdapter('http://localhost', cred)
expected = 'Basic ' + b64encode(b'Administrator:password').decode('ascii')
assert _authorization_header(client) == expected

def test_credential_rejects_unknown_kwargs(self) -> None:
with pytest.raises(TypeError, match='unexpected keyword argument'):
Credential(usernme='Administrator', password='password') # type: ignore[call-arg]
with pytest.raises(TypeError, match='unexpected keyword argument'):
Credential(jwt_token=_SAMPLE_JWT, extra='ignored') # type: ignore[call-arg]

def test_credential_hides_internal_details(self) -> None:
public_attrs = {name for name in dir(Credential.from_jwt(_SAMPLE_JWT)) if not name.startswith('_')}
assert public_attrs == {'from_callable', 'from_jwt', 'from_username_and_password'}

def test_credential_from_callable_with_jwt(self) -> None:
cred = Credential.from_callable(lambda: Credential.from_jwt(_SAMPLE_JWT))
client = _AsyncClientAdapter('https://localhost', cred)
assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}'

def test_jwt_credential_repr_redacts_token(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
rendered = repr(cred)
assert _SAMPLE_JWT not in rendered
assert '****' in rendered

def test_jwt_credential_rejects_http_endpoint(self) -> None:
with pytest.raises(ValueError, match='require a secure'):
_AsyncClientAdapter('http://localhost', Credential.from_jwt(_SAMPLE_JWT))

def test_jwt_credential_accepts_https_endpoint(self) -> None:
client = _AsyncClientAdapter('https://localhost', Credential.from_jwt(_SAMPLE_JWT))
assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}'

def test_password_credential_http_authorization_header(self) -> None:
cred = Credential.from_username_and_password('Administrator', 'password')
client = _AsyncClientAdapter('http://localhost', cred)
expected = 'Basic ' + b64encode(b'Administrator:password').decode('ascii')
assert _authorization_header(client) == expected

def test_password_credential_repr_redacts_password(self) -> None:
cred = Credential.from_username_and_password('Administrator', 'super-secret')
rendered = repr(cred)
assert 'super-secret' not in rendered
assert '****' in rendered
assert 'Administrator' in rendered

def test_dynamic_auth_sets_header_from_current_credential(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
client = _AsyncClientAdapter('https://localhost', cred)
auth = DynamicCredentialAuth(client.credential_holder)

req = Request('POST', 'https://localhost/api/v1/request')
flow = auth.auth_flow(req)
dispatched = next(flow)
assert dispatched.headers['Authorization'] == f'Bearer {_SAMPLE_JWT}'

@pytest.mark.anyio
async def test_async_dynamic_auth_sets_header_from_current_credential(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
client = _AsyncClientAdapter('https://localhost', cred)
auth = DynamicCredentialAuth(client.credential_holder)

request_url = f'{client.connection_details.url.get_formatted_url()}{client.analytics_path}'
req = Request('POST', request_url)
flow = auth.async_auth_flow(req)
dispatched = await flow.__anext__()
assert dispatched.headers['Authorization'] == f'Bearer {_SAMPLE_JWT}'

@pytest.mark.anyio
async def test_dynamic_auth_picks_up_rotated_credential(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
client = _AsyncClientAdapter('https://localhost', cred)
auth = DynamicCredentialAuth(client.credential_holder)

new_token = 'rotated.jwt.token'
await client.update_credential(Credential.from_jwt(new_token))

req = Request('POST', 'https://localhost/api/v1/request')
flow = auth.auth_flow(req)
dispatched = next(flow)
assert dispatched.headers['Authorization'] == f'Bearer {new_token}'

@pytest.mark.anyio
async def test_set_credential_same_type_updates_state(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
client = _AsyncClientAdapter('https://localhost', cred)
assert _authorization_header(client) == f'Bearer {_SAMPLE_JWT}'

new_token = 'fresh.jwt.token'
await client.update_credential(Credential.from_jwt(new_token))

assert _authorization_header(client) == f'Bearer {new_token}'

@pytest.mark.anyio
async def test_set_credential_password_to_jwt_fails(self) -> None:
cred = Credential.from_username_and_password('Administrator', 'password')
client = _AsyncClientAdapter('http://localhost', cred)
with pytest.raises(TypeError, match='Cannot switch credential type'):
await client.update_credential(Credential.from_jwt(_SAMPLE_JWT))

@pytest.mark.anyio
async def test_set_credential_jwt_to_password_fails(self) -> None:
cred = Credential.from_jwt(_SAMPLE_JWT)
client = _AsyncClientAdapter('https://localhost', cred)
with pytest.raises(TypeError, match='Cannot switch credential type'):
await client.update_credential(Credential.from_username_and_password('Administrator', 'password'))

@pytest.mark.anyio
async def test_set_credential_failure_does_not_change_state(self) -> None:
cred = Credential.from_username_and_password('Administrator', 'password')
client = _AsyncClientAdapter('http://localhost', cred)
original_header = _authorization_header(client)

with pytest.raises(TypeError):
await client.update_credential(Credential.from_jwt(_SAMPLE_JWT))

assert _authorization_header(client) == original_header


class CredentialTests(CredentialTestSuite):
@pytest.fixture(scope='class', autouse=True)
def validate_test_manifest(self) -> None:
def valid_test_method(meth: str) -> bool:
attr = getattr(CredentialTests, meth)
return callable(attr) and not meth.startswith('__') and meth.startswith('test')

method_list = [meth for meth in dir(CredentialTests) if valid_test_method(meth)]
test_list = set(CredentialTestSuite.TEST_MANIFEST).symmetric_difference(method_list)
if test_list:
pytest.fail(f'Test manifest invalid. Missing/extra tests: {test_list}.')
2 changes: 2 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@

_UNIT_TESTS = [
'acouchbase_analytics/tests/connection_t.py::ConnectionTests',
'acouchbase_analytics/tests/credential_t.py::CredentialTests',
'acouchbase_analytics/tests/json_parsing_t.py::JsonParsingTests',
'acouchbase_analytics/tests/options_t.py::ClusterOptionsTests',
'acouchbase_analytics/tests/query_options_t.py::ClusterQueryOptionsTests',
'acouchbase_analytics/tests/query_options_t.py::ScopeQueryOptionsTests',
'acouchbase_analytics/tests/test_server_t.py::ClusterTestServerTests',
'acouchbase_analytics/tests/test_server_t.py::ScopeTestServerTests',
'couchbase_analytics/tests/connection_t.py::ConnectionTests',
'couchbase_analytics/tests/credential_t.py::CredentialTests',
'couchbase_analytics/tests/duration_parsing_t.py::DurationParsingTests',
'couchbase_analytics/tests/json_parsing_t.py::JsonParsingTests',
'couchbase_analytics/tests/options_t.py::ClusterOptionsTests',
Expand Down
16 changes: 16 additions & 0 deletions couchbase_analytics/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ def start_query(self, statement: str, *args: object, **kwargs: object) -> Blocki
""" # noqa: E501
return self._impl.start_query(statement, *args, **kwargs)

def set_credential(self, credential: Credential) -> None:
"""Replace the credential used for subsequent HTTP requests.

Allows updating credentials (in particular, rotating a JWT) without restarting
the application. The new credential must be of the same type as the current
credential.

Args:
credential: The new :class:`.Credential` to use.

Raises:
TypeError: If the new credential is a different type than the current
credential.
"""
self._impl.set_credential(credential)

def shutdown(self) -> None:
"""Shuts down this cluster instance. Cleaning up all resources associated with it.

Expand Down
1 change: 1 addition & 0 deletions couchbase_analytics/cluster.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Cluster:
) -> BlockingQueryHandle: ...
@overload
def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ...
def set_credential(self, credential: Credential) -> None: ...
def shutdown(self) -> None: ...
@overload
@classmethod
Expand Down
Loading
Loading