diff --git a/conda/conda-recipes/azure-ai-evaluation/meta.yaml b/conda/conda-recipes/azure-ai-evaluation/meta.yaml index c558b0e6b9e3..0a15747b34b6 100644 --- a/conda/conda-recipes/azure-ai-evaluation/meta.yaml +++ b/conda/conda-recipes/azure-ai-evaluation/meta.yaml @@ -10,10 +10,12 @@ source: build: noarch: python number: 0 - script: "{{ PYTHON }} -m pip install . -vv" + script: "{{ PYTHON }} -m pip install . -vv --no-build-isolation" requirements: host: + - setuptools <70 + - wheel - azure-core >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - azure-identity >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - msrest >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} diff --git a/conda/conda-recipes/azure-ai-ml/meta.yaml b/conda/conda-recipes/azure-ai-ml/meta.yaml index b9cad7784dd5..cf04bb304bc4 100644 --- a/conda/conda-recipes/azure-ai-ml/meta.yaml +++ b/conda/conda-recipes/azure-ai-ml/meta.yaml @@ -10,10 +10,12 @@ source: build: noarch: python number: 0 - script: "{{ PYTHON }} -m pip install . -vv" + script: "{{ PYTHON }} -m pip install . 
-vv --no-build-isolation" requirements: host: + - setuptools <70 + - wheel - azure-core >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - azure-identity >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - msrest >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} diff --git a/eng/pipelines/templates/stages/conda-sdk-client.yml b/eng/pipelines/templates/stages/conda-sdk-client.yml index 5e5509a64fa2..c4d255b5e49b 100644 --- a/eng/pipelines/templates/stages/conda-sdk-client.yml +++ b/eng/pipelines/templates/stages/conda-sdk-client.yml @@ -256,7 +256,7 @@ extends: service: ml in_batch: ${{ parameters.release_azure_ai_ml }} channels: - - conda-forge + - https://prefix.dev/conda-forge checkout: - package: azure-ai-ml version: 1.28.1 diff --git a/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py b/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py index 83bfed097991..f581049981f4 100644 --- a/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py +++ b/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py @@ -37,7 +37,7 @@ CONDA_ENV_FILE = """name: azure-build-env channels: - - conda-forge + - https://prefix.dev/conda-forge - defaults dependencies: - python=3.10 @@ -294,7 +294,7 @@ def create_combined_sdist( [ os.path.join(config_assembled_folder, a) for a in os.listdir(config_assembled_folder) - if os.path.isfile(os.path.join(config_assembled_folder, a)) and conda_build.name in a + if os.path.isfile(os.path.join(config_assembled_folder, a)) and conda_build.name.replace("-", "_") in a ] ) ) diff --git a/eng/tools/azure-sdk-tools/ci_tools/functions.py b/eng/tools/azure-sdk-tools/ci_tools/functions.py index 2ef7c99779f0..4058652af072 100644 --- a/eng/tools/azure-sdk-tools/ci_tools/functions.py +++ b/eng/tools/azure-sdk-tools/ci_tools/functions.py @@ -260,8 +260,12 @@ def discover_targeted_packages( for pkg in collected_packages: try: parsed_packages.append(ParsedSetup.from_path(pkg)) - except RuntimeError as e: - 
logging.error(f"Unable to parse metadata for package {pkg}, omitting from build.") + except Exception as e: + # Some packages have setup.py files that import modules unavailable in the + # current environment (e.g. pkg_resources removed by setuptools>=80). Such + # packages should be omitted from the build/regression set rather than + # aborting discovery for the entire repo. + logging.error(f"Unable to parse metadata for package {pkg}, omitting from build. Reason: {e}") continue # filter for compatibility, this means excluding a package that doesn't support py36 when we are running a py36 executable diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 8dd5d718b51a..1db2cb11fa9e 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,12 @@ ## Release History +### 4.14.7 (Unreleased) + +#### Bugs Fixed +* Fixed `SELECT VALUE` aggregation classification across partitions: booleans are no longer treated as numeric aggregates, non-aggregate numeric projections are no longer merged, and `MIN`/`MAX` detection is now correct. See [PR 46692](https://github.com/Azure/azure-sdk-for-python/pull/46692) +* Fixed a bug in `query_items(feed_range=...)` where pagination could return incorrect results after a partition split caused the supplied feed range to overlap multiple physical partitions. See [PR 46692](https://github.com/Azure/azure-sdk-for-python/pull/46692) +* Fixed bug where unavailable regional endpoints were dropped from the routing list instead of being kept as fallback options. 
See [PR 45200](https://github.com/Azure/azure-sdk-for-python/pull/45200) + ### 4.14.6 (2026-02-02) #### Bugs Fixed diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index c20a879eff31..ffab91d5dc9c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -1,4 +1,4 @@ -# The MIT License (MIT) +# The MIT License (MIT) # Copyright (c) 2014 Microsoft Corporation # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -39,6 +39,11 @@ from . import documents from . import http_constants from . import _runtime_constants +from ._query_aggregate_utils import ( + _AggregatePartialClassification, + _classify_aggregate_partial, + _get_select_value_aggregate_function, +) from ._constants import _Constants as Constants from .auth import _get_authorization_header from .offer import ThroughputProperties @@ -124,6 +129,7 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]: options['accessCondition'] = {'type': 'IfNoneMatch', 'condition': if_none_match} return options + def _merge_query_results( results: dict[str, Any], partial_result: dict[str, Any], @@ -163,22 +169,13 @@ def _merge_query_results( results_docs = results.get("Documents") - # Check if both results are aggregate queries - is_partial_agg = ( - isinstance(partial_docs, list) - and len(partial_docs) == 1 - and isinstance(partial_docs[0], dict) - and partial_docs[0].get("_aggregate") is not None - ) - is_results_agg = ( - results_docs - and isinstance(results_docs, list) - and len(results_docs) == 1 - and isinstance(results_docs[0], dict) - and results_docs[0].get("_aggregate") is not None - ) + partial_aggregate_class = _classify_aggregate_partial(partial_docs, query) + results_aggregate_class = _classify_aggregate_partial(results_docs, query) - if is_partial_agg and is_results_agg: + if ( + partial_aggregate_class == _AggregatePartialClassification.OBJECT + and 
results_aggregate_class == _AggregatePartialClassification.OBJECT + ): agg_results = results_docs[0]["_aggregate"] # type: ignore[index] agg_partial = partial_docs[0]["_aggregate"] for key in agg_partial: @@ -196,33 +193,26 @@ def _merge_query_results( agg_results[key] += agg_partial[key] return results - # Check if both are VALUE aggregate queries - is_partial_value_agg = ( - isinstance(partial_docs, list) - and len(partial_docs) == 1 - and isinstance(partial_docs[0], (int, float)) - ) - is_results_value_agg = ( - results_docs - and isinstance(results_docs, list) - and len(results_docs) == 1 - and isinstance(results_docs[0], (int, float)) - ) - - if is_partial_value_agg and is_results_value_agg: - query_text = query.get("query") if isinstance(query, dict) else query - if query_text: - query_upper = query_text.upper() - # For MIN/MAX, we find the min/max of the partial results. - # For COUNT/SUM, we sum the partial results. - # Without robust query parsing, we can't distinguish them reliably. - # Defaulting to sum for COUNT/SUM. MIN/MAX VALUE queries are not fully supported client-side. - if " SELECT VALUE MIN" in query_upper: - results_docs[0] = min(results_docs[0], partial_docs[0]) # type: ignore[index] - elif " SELECT VALUE MAX" in query_upper: - results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index] - else: # For COUNT/SUM, we sum the partial results - results_docs[0] += partial_docs[0] # type: ignore[index] + if ( + partial_aggregate_class == _AggregatePartialClassification.VALUE + and results_aggregate_class == _AggregatePartialClassification.VALUE + ): + aggregate_fn = _get_select_value_aggregate_function(query) + if aggregate_fn is None: + raise ValueError( + "Invariant violation: VALUE aggregate classification requires a recognized aggregate function." 
+ ) + if aggregate_fn == "MIN": + results_docs[0] = min(results_docs[0], partial_docs[0]) # type: ignore[index] + elif aggregate_fn == "MAX": + results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index] + elif aggregate_fn == "AVG": + raise ValueError( + "VALUE AVG aggregate merge across partitions is not supported client-side." + ) + else: + # COUNT/SUM are additive. + results_docs[0] += partial_docs[0] # type: ignore[index] return results # Standard query, append documents @@ -234,6 +224,29 @@ def _merge_query_results( return results +def _raise_query_merge_value_error(merge_error: ValueError) -> None: + """Raise a clearer user-facing error for unsupported VALUE aggregate merges. + + ``SELECT VALUE AVG(...)`` partials cannot be merged correctly client-side + across multiple partition/range responses. We fail loudly instead of + falling back to list concatenation (which would silently produce + mathematically incorrect results). + + :param merge_error: ValueError raised while merging partial query results. + :type merge_error: ValueError + :raises ValueError: Always re-raises, potentially with a clearer message. + """ + merge_message = str(merge_error) + if "VALUE AVG aggregate merge across partitions is not supported client-side." in merge_message: + raise ValueError( + "Unsupported query shape for range-scoped pagination: " + "SELECT VALUE AVG(...) cannot be merged client-side when the query " + "scope spans multiple physical partitions." 
+ ) from merge_error + raise merge_error + + + def GetHeaders( # pylint: disable=too-many-statements,too-many-branches cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], default_headers: Mapping[str, Any], diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index fa1f6efdced6..0682f3e05e7b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -1,4 +1,4 @@ -# The MIT License (MIT) +# The MIT License (MIT) # Copyright (c) 2014 Microsoft Corporation # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,11 +23,12 @@ """Document client class for the Azure Cosmos database service. """ +import logging import os import urllib.parse import uuid from concurrent.futures.thread import ThreadPoolExecutor -from typing import Callable, Any, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast from typing_extensions import TypedDict from urllib3.util.retry import Retry @@ -69,10 +70,26 @@ from ._read_items_helper import ReadItemsHelperSync from ._request_object import RequestObject from ._retry_utility import ConnectionRetryPolicy -from ._routing import routing_map_provider, routing_range +from ._routing import routing_map_provider +from ._routing.feed_range_continuation import ( + _FeedRangePaginationState, + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES, + _apply_feedrange_request_headers, + _build_scope_from_overlaps, + _count_page_items_from_partial_result, + _decode_token, + _derive_initial_feedranges, + _increment_explode_iterations_or_raise, + _normalize_max_item_count, + _should_bridge_legacy_continuation, + _update_no_progress_page_count, + _validate_token_identity, + _write_query_outbound_continuation, +) from 
._inference_service import _InferenceService from .documents import ConnectionPolicy, DatabaseAccount from .partition_key import ( + _build_partition_key_from_properties, _Undefined, _Empty, _PartitionKeyKind, @@ -81,6 +98,8 @@ _return_undefined_or_empty_partition_key, ) +_LOGGER = logging.getLogger(__name__) + class CredentialDict(TypedDict, total=False): masterKey: str resourceTokens: Mapping[str, Any] @@ -139,6 +158,7 @@ def __init__( # pylint: disable=too-many-statements self.client_id = str(uuid.uuid4()) self.url_connection = url_connection self.master_key: Optional[str] = None + self.resource_tokens: Optional[Mapping[str, Any]] = None self.aad_credentials: Optional[TokenCredential] = None if auth is not None: @@ -3145,6 +3165,7 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma partition_key_range_id: Optional[str] = None, response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None, is_query_plan: bool = False, + response_headers_list: Optional[list[CaseInsensitiveDict]] = None, **kwargs: Any ) -> Tuple[list[dict[str, Any]], CaseInsensitiveDict]: """Query for more than one Azure Cosmos resources. @@ -3163,6 +3184,9 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma :type response_hook: Callable[[Mapping[str, Any], dict[str, Any]], None] :param bool is_query_plan: Specifies if the call is to fetch query plan + :param response_headers_list: + Optional list to which per-request response headers will be appended. + :type response_headers_list: Optional[list[CaseInsensitiveDict]] :returns: A list of the queried resources. :rtype: list :raises SystemError: If the query compatibility mode is undefined. @@ -3182,6 +3206,29 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma if timeout is not None: kwargs.setdefault("timeout", timeout) + # The capture dict can arrive via two upstream paths: + # 1. 
The query execution context puts it into ``options`` (the + # common case for query pagination — see + # ``_QueryExecutionContextBase._fetch_items_helper_no_retries``). + # 2. ``routing_map_provider.get_routing_map`` puts it into + # ``kwargs`` for PK-range fetches. + # Honour both so checkpoint-on-failure works on every path. + internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( + "_internal_response_headers_capture", None + ) + if internal_headers_capture is None and isinstance(options, dict): + internal_headers_capture = options.pop( + "_internal_response_headers_capture", None + ) + + def _capture_internal_headers(headers: Mapping[str, Any]) -> None: + # Local helper so flow analysis can narrow Optional[Dict] once + # and every call site stays a single line. + if internal_headers_capture is None: + return + internal_headers_capture.clear() + internal_headers_capture.update(headers) + if query: __GetBodiesFromQueryResult = result_fn else: @@ -3227,12 +3274,15 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: change_feed_state.populate_request_headers(self._routing_map_provider, headers, feed_options) request_params.headers = headers - result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) - self.last_response_headers = last_response_headers + result, get_response_headers = self.__Get(path, request_params, headers, **kwargs) + self.last_response_headers = get_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(get_response_headers) + if response_headers_list is not None: + response_headers_list.append(get_response_headers.copy()) if response_hook: - response_hook(last_response_headers, result) - return __GetBodiesFromQueryResult(result), last_response_headers - + response_hook(get_response_headers, result) + return __GetBodiesFromQueryResult(result), get_response_headers query = self.__CheckAndUnifyQueryFormat(query) if (self._query_compatibility_mode in 
(CosmosClientConnection._QueryCompatibilityMode.Default, @@ -3264,8 +3314,10 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: req_headers[http_constants.HttpHeaders.IsQuery] = "true" base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id) - # Check if the over lapping ranges can be populated + # Check if the overlapping ranges can be populated feed_range_epk = None + container_properties = kwargs.pop("container_properties", None) + is_full_pk_structured_scope = False if "feed_range" in kwargs: feed_range = kwargs.pop("feed_range") feed_range_epk = FeedRangeInternalEpk.from_json(feed_range).get_normalized_range() @@ -3274,73 +3326,328 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: prefix_partition_key_value: _SequentialPartitionKeyType = kwargs.pop("prefix_partition_key_value") feed_range_epk = ( prefix_partition_key_obj._get_epk_range_for_prefix_partition_key(prefix_partition_key_value)) + elif options.get("partitionKey") is not None and container_properties is not None: + partition_key_value = options["partitionKey"] + partition_key_obj = _build_partition_key_from_properties(container_properties) + if not partition_key_obj._is_prefix_partition_key(partition_key_value): + # Once we route full-PK queries through feed-range pagination, + # avoid sending the legacy partition-key header on the same request. + req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) + # Full-PK returns a single-value inclusive range; normalize to + # [min, max) before routing-map overlap resolution. 
+ feed_range_epk = partition_key_obj._get_epk_range_for_partition_key( + partition_key_value + ).to_normalized_range() + is_full_pk_structured_scope = True # If feed_range_epk exist, query with the range if feed_range_epk is not None: - over_lapping_ranges = self._routing_map_provider.get_overlapping_ranges(resource_id, [feed_range_epk], - options) - # It is possible to get more than one over lapping range. We need to get the query results for each one - results: dict[str, Any] = {} - # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over - # lapping physical partition. The EPK sub range will be one of four: - # 1) Will have a range min equal to the feed range EPK min, and a range max equal to the over lapping - # partition - # 2) Will have a range min equal to the over lapping partition range min, and a range max equal to the - # feed range EPK range max. - # 3) will match exactly with the current over lapping physical partition, so we just return the over lapping - # physical partition's partition key id. 
- # 4) Will equal the feed range EPK since it is a sub range of a single physical partition - for over_lapping_range in over_lapping_ranges: - single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range) - # Since the range min and max are all Upper Cased string Hex Values, - # we can compare the values lexicographically - EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min), - range_max=min(single_range.max, feed_range_epk.max), - isMinInclusive=True, isMaxInclusive=False) - if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max: - # The Epk Sub Range spans exactly one physical partition - # In this case we can route to the physical pk range id - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] + if resource_id is None: + raise ValueError("resource_id is required for feed_range continuation.") + # The None-check above already narrows ``resource_id`` to ``str`` + # for the rest of this block. Bind it to a clearly-named local so + # the feed_range helpers below read as ``resource_id_str`` instead + # of the generic ``resource_id``. + resource_id_str: str = resource_id + # Decode and validate inbound continuation for this request. + # ``None`` means start from the beginning of the requested + # feed range. + page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) + # Single shared copy of options for routing-map lookups in this call. + # ``get_overlapping_ranges`` does not mutate options; copying once + # avoids per-iteration ``dict(options)`` allocations. + routing_options = dict(options) + inbound_serialized_continuation = options.get("continuation") + inbound_token_payload = _decode_token(inbound_serialized_continuation) + legacy_bridge_in_use = False + # Cache for the input scope's single-partition classification. 
+ # We compute it at most once per __QueryFeed call so inbound + # bridge-detection, mid-page checkpoint, and end-of-page outbound + # writer all agree even if the PK range cache refreshes mid-call. + cached_is_single_partition: Optional[bool] = None + + def _is_input_scope_single_partition() -> bool: + """Return True when the caller input range currently maps to one physical partition. + + Result is cached for the duration of this __QueryFeed call. + + :returns: True if the input scope maps to a single physical partition. + :rtype: bool + """ + nonlocal cached_is_single_partition + if cached_is_single_partition is None: + scope_overlaps = self._routing_map_provider.get_overlapping_ranges( + resource_id, [feed_range_epk], routing_options + ) + cached_is_single_partition = ( + len(_derive_initial_feedranges(feed_range_epk, scope_overlaps)) == 1 + ) + return cached_is_single_partition + + if inbound_serialized_continuation and inbound_token_payload is None: + scope_is_single_partition = False + if not is_full_pk_structured_scope: + scope_is_single_partition = _is_input_scope_single_partition() + if _should_bridge_legacy_continuation( + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + scope_is_single_partition, + ): + if is_full_pk_structured_scope: + _LOGGER.warning( + "Full-PK query continuation token is in legacy format; " + "bridging it into structured pagination state for resume." + ) + else: + _LOGGER.warning( + "Feed-range query continuation token is in legacy format and the input scope " + "currently maps to one physical partition; honoring legacy continuation for resume." + ) + legacy_bridge_in_use = True + else: + _LOGGER.warning( + "Feed-range query continuation token is not in the supported structured format; " + "restarting this feed_range query from the beginning." 
+ ) + if inbound_token_payload is not None: + _validate_token_identity( + inbound_token_payload, + resource_id_str, + query, + feed_range_epk, + ) + pagination_state = _FeedRangePaginationState.from_inbound( + inbound_token_payload, page_size_hint + ) + elif legacy_bridge_in_use and inbound_serialized_continuation: + pagination_state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + feed_range_epk, + inbound_serialized_continuation, + page_size_hint, + ) + else: + # First call. Ask the routing map which + # partitions the input feed_range overlaps right now and turn + # each overlap into a feedrange (intersection of that partition + # and the input feed_range). + first_overlaps = self._routing_map_provider.get_overlapping_ranges( + resource_id, [feed_range_epk], routing_options + ) + all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps) + if not all_feedranges: + # The input feed_range overlaps no current physical + # partition. Fall through to the regular __Post path + # below so cross-partition gating can still surface + # (e.g. an empty ``partition_key=[]`` prefix with + # ``enableCrossPartitionQuery=False`` must raise + # ``BAD_REQUEST``). 
+ pagination_state = None else: - # The Epk Sub Range spans less than a single physical partition - # In this case we route to the physical partition and - # pass the epk sub range to the headers to filter within partition - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] - req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min - req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max - req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" - partial_result, last_response_headers = self.__Post( - path, request_params, query, req_headers, **kwargs + pagination_state = _FeedRangePaginationState.from_derived_feedranges( + all_feedranges, + page_size_hint, + ) + + if pagination_state is not None: + results: dict[str, Any] = {} + feedrange_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() + consecutive_no_progress_pages = 0 + + def _checkpoint_and_reraise(error: Exception) -> None: + # Intentionally broad: stamp the latest resumable checkpoint + # for any mid-page failure, then re-raise the original error. + self.last_response_headers = feedrange_response_headers + try: + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + (not is_full_pk_structured_scope) and _is_input_scope_single_partition(), + ) + except Exception as continuation_write_error: # pylint: disable=broad-exception-caught + _LOGGER.warning( + "Failed to write continuation while handling query POST failure: %s", + continuation_write_error, + ) + raise error + + # NOTE: Keep this feed_range pagination loop in sync with + # ``azure/cosmos/aio/_cosmos_client_connection_async.py::__QueryFeed``. 
+ while pagination_state.can_issue_request(): + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + + # Wrap all mid-page work that can raise (routing lookups, + # scope build/explode, backend POST, and result merge) so + # we always stamp a resumable checkpoint into + # last_response_headers[Continuation] before re-raising. + # Post-result accounting below is pure local bookkeeping + # and is intentionally left outside this try. + try: + # Look up the live routing map for the current feedrange. + # Doing this every iteration is what makes the token + # split-safe. + overlapping = self._routing_map_provider.get_overlapping_ranges( + resource_id, [head_feedrange], routing_options + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + # If routing returns multiple overlaps, the head sub-range now spans a split + # that occurred after the token was created. Re-slice and re-resolve until + # each head maps to one partition. See + # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. + explode_iterations = 0 + while pagination_state.explode_on_multi_overlap(overlapping): + explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + overlapping = self._routing_map_provider.get_overlapping_ranges( + resource_id, [head_feedrange], routing_options + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + head_feedrange = pagination_state.head_range + if head_feedrange is None: + continue + + # Populate request headers for this single backend POST. + # The shared helper handles partition routing (PKR id + + # optional EPK filter), page-size cap, and continuation + # set/clear so the same rules apply to sync and async. 
+ _apply_feedrange_request_headers( + req_headers, + overlapping, + partition_scope, + head_feedrange, + pagination_state.page_size_hint, + pagination_state.head_bc, + ) + # Use the session token for this specific partition so we don't + # send a compound token covering all partitions. + base.set_session_token_header( + self, req_headers, path, request_params, options, overlapping[0]["id"] + ) + + backend_query_result, backend_response_headers = self.__Post( + path, request_params, query, req_headers, **kwargs + ) + feedrange_response_headers = backend_response_headers + self.last_response_headers = feedrange_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) + + # Merge results, falling back to a plain extend if the + # aggregating merge raises (it can on aggregated queries + # during splits). + try: + results = base._merge_query_results(results, backend_query_result, query) + except ValueError as merge_error: + base._raise_query_merge_value_error(merge_error) + except (TypeError, KeyError) as merge_error: + _LOGGER.warning( + "Falling back to non-aggregate merge after aggregate merge failure: %s", + merge_error, + ) + results_docs = results.get("Documents") if results else None + partial_docs = backend_query_result.get("Documents") if backend_query_result else None + if isinstance(results_docs, list) and isinstance(partial_docs, list): + results_docs.extend(partial_docs) + elif backend_query_result: + results = backend_query_result + except Exception as mid_page_error: # pylint: disable=broad-exception-caught + _checkpoint_and_reraise(mid_page_error) + + previous_feedrange = pagination_state.head_range + previous_backend_continuation = pagination_state.head_bc + page_items_returned = _count_page_items_from_partial_result(backend_query_result, query) + if response_headers_list is not None: + 
response_headers_list.append(backend_response_headers.copy()) + if response_hook: + response_hook(backend_response_headers, backend_query_result) + pagination_state.apply_post_result( + page_items_returned, + backend_response_headers.get(http_constants.HttpHeaders.Continuation), + ) + consecutive_no_progress_pages = _update_no_progress_page_count( + consecutive_no_progress_pages, + page_items_returned, + previous_feedrange, + previous_backend_continuation, + pagination_state.head_range, + pagination_state.head_bc, + ) + if ( + consecutive_no_progress_pages >= _MAX_CONSECUTIVE_NO_PROGRESS_PAGES + and consecutive_no_progress_pages % _MAX_CONSECUTIVE_NO_PROGRESS_PAGES == 0 + ): + # Warning-only: do not fail fast here. + current_head = pagination_state.head_range + head_min = current_head.min if current_head else "" + head_max = current_head.max if current_head else "" + _LOGGER.warning( + "Feed-range query has returned 0 items for %s consecutive continuation pages " + "with the same continuation token and partition key range [%s, %s); continuing scan.", + consecutive_no_progress_pages, + head_min, + head_max, + ) + + # maxItemCount is a per-request hint. Return this SDK page + # after the first non-empty logical result instead of filling + # an exact target count by issuing extra backend requests. + if page_items_returned > 0: + break + + # Pagination loop is done — write the final outbound + # continuation (or clear the header if the queue is fully + # drained) so the caller's ``by_page`` loop terminates. 
+ _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + (not is_full_pk_structured_scope) and _is_input_scope_single_partition(), ) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) - # Introducing a temporary complex function into a critical path to handle aggregated queries - # during splits, as a precaution falling back to the original logic if anything goes wrong - try: - results = base._merge_query_results(results, partial_result, query) - except Exception: # pylint: disable=broad-exception-caught - # If the new merge logic fails, fall back to the original logic. - if results: - results["Documents"].extend(partial_result["Documents"]) - else: - results = partial_result - if response_hook: - response_hook(last_response_headers, partial_result) - # if the prefix partition query has results lets return it - if results: - return __GetBodiesFromQueryResult(results), last_response_headers - - result, last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(req_headers, result, last_response_headers) - if last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: + # End feed_range pagination block. 
+ self.last_response_headers = feedrange_response_headers + + # if the prefix partition query has results lets return it + if results: + if feedrange_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: + index_metrics_raw = feedrange_response_headers[http_constants.HttpHeaders.IndexUtilization] + feedrange_response_headers[http_constants.HttpHeaders.IndexUtilization] = ( + _utils.get_index_metrics_info(index_metrics_raw)) + return __GetBodiesFromQueryResult(results), feedrange_response_headers + return [], feedrange_response_headers + + result, post_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = post_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(post_response_headers) + self._UpdateSessionIfRequired(req_headers, result, post_response_headers) + if post_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization - index_metrics_raw = last_response_headers[INDEX_METRICS_HEADER] - last_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) + index_metrics_raw = post_response_headers[INDEX_METRICS_HEADER] + post_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) + if response_headers_list is not None: + response_headers_list.append(post_response_headers.copy()) if response_hook: - response_hook(last_response_headers, result) + response_hook(post_response_headers, result) - return __GetBodiesFromQueryResult(result), last_response_headers + return __GetBodiesFromQueryResult(result), post_response_headers def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, excluded_locations: Optional[Sequence[str]] = None, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py index 1a866c2df873..e57dffbccc06 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py @@ -51,6 +51,10 @@ def __init__(self, client, options): self._has_started = False self._has_finished = False self._buffer = deque() + self._resource_link = None + # Per-query mutable capture used by __QueryFeed to report response + # headers (including failure checkpoints) without crossing requests. + self._internal_response_headers_capture = {} def _get_initial_continuation(self): if "continuation" in self._options: @@ -118,6 +122,9 @@ async def _fetch_items_helper_no_retries(self, fetch_function): """ fetched_items = [] new_options = copy.deepcopy(self._options) + # Clear stale values from prior pages before issuing a new fetch. + self._internal_response_headers_capture.clear() + new_options["_internal_response_headers_capture"] = self._internal_response_headers_capture while self._continuation or not self._has_started: new_options["continuation"] = self._continuation @@ -184,9 +191,13 @@ async def callback(**kwargs): # pylint: disable=unused-argument # Refresh routing map to get new partition key ranges self._client.refresh_routing_map_provider() - # Reset execution context state to allow retry from the beginning + + # Reset execution context state for retry. If __QueryFeed already + # stamped a checkpoint continuation on failure, resume from it. 
+ continuation_key = http_constants.HttpHeaders.Continuation + checkpoint_continuation = self._internal_response_headers_capture.get(continuation_key) self._has_started = False - self._continuation = None + self._continuation = checkpoint_continuation # Retry immediately (no backoff needed for partition splits) continue raise # Not a partition split error, propagate immediately diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py index 528ca87f2586..a38be4b21df6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py @@ -49,6 +49,10 @@ def __init__(self, client, options): self._has_started = False self._has_finished = False self._buffer = deque() + self._resource_link = None + # Per-query mutable capture used by __QueryFeed to report response + # headers (including failure checkpoints) without crossing requests. + self._internal_response_headers_capture = {} def _get_initial_continuation(self): if "continuation" in self._options: @@ -116,6 +120,9 @@ def _fetch_items_helper_no_retries(self, fetch_function): """ fetched_items = [] new_options = copy.deepcopy(self._options) + # Clear stale values from prior pages before issuing a new fetch. + self._internal_response_headers_capture.clear() + new_options["_internal_response_headers_capture"] = self._internal_response_headers_capture while self._continuation or not self._has_started: new_options["continuation"] = self._continuation @@ -180,11 +187,15 @@ def callback(**kwargs): # pylint: disable=unused-argument max_retries ) - # Refresh routing map to get new partition key ranges + + # Refresh routing map to get new partition key ranges. 
self._client.refresh_routing_map_provider() - # Reset execution context state to allow retry from the beginning + # Reset execution context state for retry. If __QueryFeed already + # stamped a checkpoint continuation on failure, resume from it. + continuation_key = http_constants.HttpHeaders.Continuation + checkpoint_continuation = self._internal_response_headers_capture.get(continuation_key) self._has_started = False - self._continuation = None + self._continuation = checkpoint_continuation # Retry immediately (no backoff needed for partition splits) continue raise # Not a partition split error, propagate immediately diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index cf8239488712..dd45279b6f59 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -513,8 +513,11 @@ def get_preferred_regional_routing_contexts( else: regional_endpoints.append(regional_endpoint) - # If all preferred locations are unavailable, honor the preferred list by trying them anyway. - if not regional_endpoints and unavailable_endpoints: + # Always append unavailable endpoints to the end of the list so they can be + # used as a last resort. This ensures that when all healthy endpoints are filtered + # out (e.g., by excluded_locations), the SDK can still fall back to unavailable + # regional endpoints rather than the global endpoint. 
+ if unavailable_endpoints: regional_endpoints.extend(unavailable_endpoints) # If there are no preferred locations or none of the preferred locations are in the account, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py new file mode 100644 index 000000000000..0cae17919e39 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py @@ -0,0 +1,270 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +from enum import Enum +from typing import Any, Optional, Union + + +# Used by query paging and query merge paths to decide whether a row is +# a normal row or part of an aggregate result. +class _AggregatePartialClassification(Enum): + """Classification for one-partition query partial payloads.""" + + NONE = "none" + OBJECT = "object" + VALUE = "value" + + +def _extract_query_text(query: Optional[Union[str, dict[str, Any]]]) -> Optional[str]: + """Extract SQL text from a string or query-spec dictionary. + + :param query: Query text or query spec dictionary. + :type query: Optional[Union[str, dict[str, Any]]] + :returns: Query text when present; otherwise ``None``. + :rtype: Optional[str] + """ + if isinstance(query, str): + return query + if isinstance(query, dict): + query_text = query.get("query") + if isinstance(query_text, str): + return query_text + return None + + +def _strip_sql_block_comments(query_text: str) -> str: + """Return ``query_text`` with ``/* ... */`` comment spans removed. + + The aggregate detector is a lightweight scanner, so this helper keeps the + same lightweight approach and removes only block comments before scanning. 
+ Quoted strings are preserved so comment-like text inside literals does not + get stripped. + + :param query_text: Raw query text. + :type query_text: str + :returns: Query text with block comments removed. + :rtype: str + """ + out: list[str] = [] + index = 0 + length = len(query_text) + in_quote: Optional[str] = None + + while index < length: + ch = query_text[index] + + if in_quote is not None: + out.append(ch) + # SQL-style escaped quote inside same quote type, e.g. 'it''s'. + if ch == in_quote and index + 1 < length and query_text[index + 1] == in_quote: + out.append(query_text[index + 1]) + index += 2 + continue + if ch == in_quote: + in_quote = None + index += 1 + continue + + if ch in ("'", '"'): + in_quote = ch + out.append(ch) + index += 1 + continue + + if ch == "/" and index + 1 < length and query_text[index + 1] == "*": + index += 2 + while index + 1 < length and not (query_text[index] == "*" and query_text[index + 1] == "/"): + index += 1 + if index + 1 < length: + index += 2 + # Preserve token separation where a comment was removed. + out.append(" ") + continue + + out.append(ch) + index += 1 + + return "".join(out) + + +def _get_select_value_aggregate_function(query: Optional[Union[str, dict[str, Any]]]) -> Optional[str]: + """Identify the aggregate function for ``SELECT VALUE`` aggregate queries. + + This is a lightweight text heuristic (not a SQL parser). It extracts only + the OUTER ``SELECT VALUE`` projection and then matches aggregate function + names in that projection so nested subqueries do not drive outer + classification. + + :param query: Query text or query spec dictionary. + :type query: Optional[Union[str, dict[str, Any]]] + :returns: One of ``COUNT``, ``SUM``, ``MIN``, ``MAX``, ``AVG`` when matched; otherwise ``None``. 
+ :rtype: Optional[str] + """ + query_text = _extract_query_text(query) + if not query_text: + return None + + without_comments = _strip_sql_block_comments(query_text) + normalized = " ".join(without_comments.upper().split()) + projection = _extract_outer_select_value_projection(normalized) + if projection is None: + return None + + projection = _unwrap_outer_parentheses(projection) + # A projection-level subquery should not classify as an outer VALUE aggregate. + if projection.startswith("SELECT VALUE "): + return None + + return _find_top_level_aggregate_function(projection) + + +def _find_top_level_aggregate_function(projection: str) -> Optional[str]: + """Return an aggregate function name only when it appears at the top level. + + This prevents nested projection expressions (for example ARRAY(SELECT VALUE + COUNT(...))) from being misclassified as outer VALUE aggregates. + + :param projection: SELECT VALUE projection text to inspect. + :type projection: str + :returns: Aggregate function name when matched at top level; otherwise ``None``. + :rtype: Optional[str] + """ + aggregate_fns = {"COUNT", "SUM", "MIN", "MAX", "AVG"} + depth = 0 + index = 0 + length = len(projection) + + while index < length: + ch = projection[index] + if ch == "(": + depth += 1 + index += 1 + continue + if ch == ")": + if depth > 0: + depth -= 1 + index += 1 + continue + + if depth == 0 and (ch.isalpha() or ch == "_"): + start = index + index += 1 + while index < length and (projection[index].isalnum() or projection[index] == "_"): + index += 1 + token = projection[start:index] + + if token in aggregate_fns: + lookahead = index + while lookahead < length and projection[lookahead].isspace(): + lookahead += 1 + if lookahead < length and projection[lookahead] == "(": + return token + continue + + index += 1 + + return None + + +def _unwrap_outer_parentheses(text: str) -> str: + """Strip redundant outer parentheses while preserving inner structure. 
+ + :param text: Projection text to normalize. + :type text: str + :returns: Projection text with only redundant outer parentheses removed. + :rtype: str + """ + candidate = text.strip() + while candidate.startswith("(") and candidate.endswith(")"): + depth = 0 + balanced = True + outer_pair = False + for idx, char in enumerate(candidate): + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + if depth < 0: + balanced = False + break + # Closing the opening '(' at index 0 means we found the outer pair. + if depth == 0: + outer_pair = idx == len(candidate) - 1 + break + if not balanced or not outer_pair: + break + candidate = candidate[1:-1].strip() + return candidate + + +def _extract_outer_select_value_projection(normalized_query: str) -> Optional[str]: + """Return the outer ``SELECT VALUE`` projection text up to the outer ``FROM``. + + Uses a lightweight parenthesis-depth scan so nested subqueries do not + influence outer aggregate detection. + + :param normalized_query: Uppercased, whitespace-normalized query text. + :type normalized_query: str + :returns: Outer ``SELECT VALUE`` projection when found; otherwise ``None``. + :rtype: Optional[str] + """ + select_value = "SELECT VALUE" + # Minimal hardening: only classify when the OUTER query starts with + # SELECT VALUE. This avoids matching nested SELECT VALUE occurrences. 
+ if not normalized_query.startswith(select_value): + return None + start_idx = 0 + + projection_start = start_idx + len(select_value) + if projection_start < len(normalized_query) and normalized_query[projection_start] == " ": + projection_start += 1 + + depth = 0 + index = projection_start + while index <= len(normalized_query) - 4: + ch = normalized_query[index] + if ch == "(": + depth += 1 + elif ch == ")" and depth > 0: + depth -= 1 + + if depth == 0 and normalized_query[index:index + 4] == "FROM": + prev_char = normalized_query[index - 1] if index > 0 else " " + next_char = normalized_query[index + 4] if index + 4 < len(normalized_query) else " " + if not (prev_char.isalnum() or prev_char == "_") and not (next_char.isalnum() or next_char == "_"): + projection = normalized_query[projection_start:index].strip() + return projection or None + index += 1 + + return None + + +def _classify_aggregate_partial( + docs: Any, + query: Optional[Union[str, dict[str, Any]]] +) -> _AggregatePartialClassification: + """Classify whether a partial result row is part of an aggregate result. + + :param docs: Partial ``Documents`` payload from one backend response. + :type docs: Any + :param query: Query text or query spec dictionary. + :type query: Optional[Union[str, dict[str, Any]]] + :returns: Aggregate partial classification. + :rtype: _AggregatePartialClassification + """ + if not isinstance(docs, list) or len(docs) != 1: + return _AggregatePartialClassification.NONE + + row = docs[0] + if isinstance(row, dict) and row.get("_aggregate") is not None: + return _AggregatePartialClassification.OBJECT + + # bool is intentionally excluded: VALUE-aggregate merge semantics are numeric. 
+ if isinstance(row, (int, float)) and not isinstance(row, bool): + if _get_select_value_aggregate_function(query) is not None: + return _AggregatePartialClassification.VALUE + + return _AggregatePartialClassification.NONE diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py new file mode 100644 index 000000000000..e2e4462d8a04 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py @@ -0,0 +1,889 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""Shared helpers for the structured ``feed_range`` continuation token. + +Both sync and async ``__QueryFeed`` implementations use this module for +token wire format, request fingerprinting, and feed-range routing helpers. + +The token stores an ordered ``c`` list of ``{min, max, bc}`` entries. +Pagination reads and updates the queue head, then advances when the head +is drained. +""" + +import base64 +import binascii +import json +from collections import deque +from typing import Any, Deque, Iterable, List, MutableMapping, Optional, Tuple + +from .. import http_constants +from .._cosmos_integers import _UInt128 +from .._cosmos_murmurhash3 import murmurhash3_128 +from .._query_aggregate_utils import _AggregatePartialClassification, _classify_aggregate_partial +from . import routing_range + + +# ----- Token wire-format constants --------------------------------------- +# Field codes for the v=1 envelope. +_TOKEN_VERSION = 1 +# Token schema version so decoders can reject unknown envelope shapes. +_FIELD_VERSION = "v" +# Resource ID for the container that originally produced this token. +_FIELD_COLLECTION_RID = "cr" +# Fingerprint of query text + parameter values to prevent wrong-query resume. +_FIELD_QUERY_HASH = "qh" +# Fingerprint of the caller's input feed_range to prevent wrong-scope resume. 
+_FIELD_FEEDRANGE_HASH = "frh" +# Ordered list of {min, max, bc} entries for the requested feed range. +# Iteration state comes from the list order; there is no separate +# top-level "current" field. +_FIELD_CONTINUATIONS = "c" +# Backend continuation for ONE entry. Lives INSIDE each ``c[i]`` entry, +# never at the envelope level. ``null`` means "this sub-range has not +# been started, or has been fully drained". +_FIELD_BACKEND_CONTINUATION = "bc" +# Observability threshold for repeated empty pages with no continuation/feedrange movement. +# This is warning-only (not a hard stop); pagination continues until the queue drains. +_MAX_CONSECUTIVE_NO_PROGRESS_PAGES = 1000 +# Safety guard for pathological split re-resolution loops. +_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS = 50 + + +# ----- Hash helpers ------------------------------------------------------ +def _stable_hash_128(payload: bytes) -> str: + """Stable 128-bit hex digest of ``payload``. + + Uses ``MurmurHash3_128`` (the same helper ``partition_key.py`` uses + for EPK routing). The fingerprint is non-cryptographic and used + only for an equality check inside ``_decode_token``: on resume the + SDK recomputes the same hash from the live call's inputs and + raises if it does not match the value baked into the saved token. + A cryptographic hash buys nothing here because the field is never + sent to the service and is never used as proof of input. + + :param payload: Bytes to hash. + :type payload: bytes + :returns: A 32-character hexadecimal digest. + :rtype: str + """ + return murmurhash3_128(bytearray(payload), _UInt128(0, 0)).as_hex() + + +def _hash_query_spec(query: Any) -> str: + """Hash query text + (parameter name, JSON-canonical value) pairs. + + Resume requires the exact same query shape, not a semantically + equivalent one. ``query`` may be either a string or the dict form + produced by ``__CheckAndUnifyQueryFormat``. + + :param query: Query text or query spec dictionary. 
+ :type query: str or dict + :returns: Stable hash for query text and parameters. + :rtype: str + """ + parameters: list = [] + parts: List[bytes] = [] + if isinstance(query, dict): + parts.append((query.get("query") or "").encode("utf-8")) + parameters = query.get("parameters") or [] + else: + parts.append((query or "").encode("utf-8")) + parts.append(b"\0") + for p in parameters: + parts.append((p.get("name", "") or "").encode("utf-8")) + parts.append(b"\0") + parts.append( + json.dumps(p.get("value"), sort_keys=True, separators=(",", ":")).encode("utf-8") + ) + parts.append(b"\0") + return _stable_hash_128(b"".join(parts)) + + +def _hash_feed_range(feed_range: routing_range.Range) -> str: + """Stable 128-bit fingerprint of the INPUT feed_range. + + Detects a token that was created against a different feed_range on + the same container being replayed against the wrong scope. + + The input is first converted to a standard ``[min, max)`` form via + ``Range.to_normalized_range()`` (idempotent — returns ``self`` when + already normalized). The canonical JSON intentionally carries only + ``min`` and ``max``: under the normalized form the inclusivity + flags are constants (``True``/``False``), so hashing them adds no + signal and would only mask the fact that the fingerprint identifies + the *logical EPK interval*, not the on-the-wire representation of + the bounds. Two ``Range`` objects describing the same logical + interval (e.g. ``[A, B)`` and the equivalent ``(A-1, B-1]``) hash + equal. + + :param feed_range: Input feed range. + :type feed_range: ~azure.cosmos._routing.routing_range.Range + :returns: Stable feed range fingerprint. 
+ :rtype: str + """ + normalized = feed_range.to_normalized_range() + canonical = json.dumps( + {"min": normalized.min, "max": normalized.max}, + sort_keys=True, + separators=(",", ":"), + ) + return _stable_hash_128(canonical.encode("utf-8")) + + +# ----- Token codec ------------------------------------------------------- +def _encode_token(payload: dict) -> str: + """JSON-serialize ``payload`` then base64-encode to a single ASCII blob. + + :param payload: Token envelope to serialize. + :type payload: dict + :returns: Base64-encoded token string. + :rtype: str + """ + return base64.b64encode( + json.dumps(payload, separators=(",", ":")).encode("utf-8") + ).decode("ascii") + + +def _decode_token(serialized: Optional[str]) -> Optional[dict]: + """Decode a continuation string into our token dict, or ``None``. + + Returns ``None`` when ``serialized`` is empty or not in our shape. + + Raises ``ValueError`` only when the input parses as our shape but is + structurally invalid (for example unknown ``v`` or missing fields). + + :param serialized: Encoded continuation token from the caller. + :type serialized: Optional[str] + :returns: Decoded token payload when valid; otherwise ``None``. + :rtype: Optional[dict] + """ + if not serialized: + return None + try: + decoded_bytes = base64.b64decode(serialized, validate=True) + decoded = json.loads(decoded_bytes.decode("utf-8")) + except (ValueError, TypeError, UnicodeDecodeError, binascii.Error): + return None # not our shape -> start fresh + if not isinstance(decoded, dict) or _FIELD_VERSION not in decoded: + return None + version = decoded.get(_FIELD_VERSION) + if version != _TOKEN_VERSION: + raise ValueError( + "Unsupported feed_range continuation token version: {}. 
" + "This SDK supports version {}.".format(version, _TOKEN_VERSION) + ) + _validate_v1_token_structure(decoded) + return decoded + + +def _validate_v1_token_structure(decoded: dict) -> None: + """Validate required v1 token fields so downstream code can index + them without checking for ``KeyError``. + + :param decoded: Decoded token payload to validate. + :type decoded: dict + """ + if not isinstance(decoded.get(_FIELD_COLLECTION_RID), str): + raise ValueError("Malformed feed_range continuation token: 'cr' is required.") + if not isinstance(decoded.get(_FIELD_QUERY_HASH), str): + raise ValueError("Malformed feed_range continuation token: 'qh' is required.") + if not isinstance(decoded.get(_FIELD_FEEDRANGE_HASH), str): + raise ValueError("Malformed feed_range continuation token: 'frh' is required.") + # ``bc`` must be per-entry inside ``c[i]``; top-level ``bc`` is invalid. + if _FIELD_BACKEND_CONTINUATION in decoded: + raise ValueError( + "Malformed feed_range continuation token: top-level 'bc' is not " + "supported; 'bc' must live inside each 'c' entry." + ) + + entries = decoded.get(_FIELD_CONTINUATIONS) + if not isinstance(entries, list) or not entries: + # Producers clear the continuation header when drained, so + # an on-wire token must contain at least one entry. + raise ValueError( + "Malformed feed_range continuation token: '{}' is required and " + "must be a non-empty list.".format(_FIELD_CONTINUATIONS) + ) + for idx, entry in enumerate(entries): + if not isinstance(entry, dict): + raise ValueError( + "Malformed feed_range continuation token: '{}[{}]' must be an object.".format( + _FIELD_CONTINUATIONS, idx + ) + ) + _validate_range_dict(entry, "{}[{}]".format(_FIELD_CONTINUATIONS, idx)) + + +def _validate_range_dict(range_dict: dict, field_name: str) -> None: + """Each persisted feedrange is a {'min': str, 'max': str, 'bc': str|null} dict. + + :param range_dict: Serialized feed range dictionary. 
+ :type range_dict: dict + :param field_name: Field label used in validation messages. + :type field_name: str + """ + if not isinstance(range_dict.get("min"), str) or not isinstance(range_dict.get("max"), str): + raise ValueError( + "Malformed feed_range continuation token: '{}' and '{}' are required.".format( + f"{field_name}.min", f"{field_name}.max" + ) + ) + if _FIELD_BACKEND_CONTINUATION not in range_dict: + raise ValueError( + "Malformed feed_range continuation token: '{}.bc' is required (use null when absent).".format( + field_name + ) + ) + bc_value = range_dict[_FIELD_BACKEND_CONTINUATION] + if bc_value is not None and not isinstance(bc_value, str): + raise ValueError( + "Malformed feed_range continuation token: '{}.bc' must be a string or null.".format( + field_name + ) + ) + + +# ----- Feedrange / routing helpers --------------------------------------- +def _dict_to_range(range_dict: dict) -> routing_range.Range: + """Convert a persisted ``{'min': ..., 'max': ...}`` dict back into a ``Range``. + + :param range_dict: Persisted feed range dictionary. + :type range_dict: dict + :returns: Routing range instance. + :rtype: ~azure.cosmos._routing.routing_range.Range + """ + return routing_range.Range( + range_min=range_dict["min"], + range_max=range_dict["max"], + isMinInclusive=True, + isMaxInclusive=False, + ) + + +def _validate_token_identity( + inbound: dict, + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, +) -> None: + """Confirm the inbound token was created for the same collection, + query, and feed_range the current call is using. If any of the + three fingerprints disagrees, raise ``ValueError`` so the caller + finds out instead of silently getting rows from a different + request. + + :param inbound: Decoded inbound token payload. + :type inbound: dict + :param resource_id: Current collection resource ID. + :type resource_id: str + :param query: Current query spec. 
+ :type query: str or dict + :param feed_range_epk: Current feed range scope. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + """ + expected_qh = _hash_query_spec(query) + expected_frh = _hash_feed_range(feed_range_epk) + if inbound[_FIELD_COLLECTION_RID] != resource_id: + raise ValueError( + "Continuation token was created for a different collection " + "(collection rid mismatch)." + ) + if inbound[_FIELD_QUERY_HASH] != expected_qh: + raise ValueError( + "Continuation token was created with a different query " + "(query hash mismatch). Resume requires the exact same query shape." + ) + if inbound[_FIELD_FEEDRANGE_HASH] != expected_frh: + raise ValueError( + "Continuation token was created for a different feed_range " + "(feed_range hash mismatch)." + ) + + +def _should_bridge_legacy_continuation( + inbound_serialized_continuation: Optional[str], + inbound_token_payload: Optional[dict], + is_full_pk_structured_scope: bool, + is_single_partition_scope: bool, +) -> bool: + """Whether to bridge an inbound legacy continuation into pagination state. + + We bridge only when the inbound continuation exists, did not decode as + structured ``v=1`` (legacy/opaque token), and the current request scope can + be represented safely by a single legacy continuation slot: + + * full-PK structured scope (always single partition), or + * non-full-PK scope that currently maps to one physical partition. + + :param inbound_serialized_continuation: The raw inbound continuation token, if any. + :type inbound_serialized_continuation: str or None + :param inbound_token_payload: The decoded structured (``v=1``) token payload, or + ``None`` if the inbound token is legacy/opaque. + :type inbound_token_payload: dict or None + :param bool is_full_pk_structured_scope: True if the request scope is a full + partition key (always single partition). + :param bool is_single_partition_scope: True if the request scope currently maps + to a single physical partition. 
+ :returns: True if the legacy continuation should be bridged into pagination state. + :rtype: bool + """ + return bool( + inbound_serialized_continuation + and inbound_token_payload is None + and (is_full_pk_structured_scope or is_single_partition_scope) + ) + + +def _extract_resume_queue( + inbound: dict, +) -> List[Tuple[routing_range.Range, Optional[str]]]: + """Decode the ``c`` list into an ordered list of ``(range, bc)`` pairs. + + The wire format stores a single ordered ``c`` list of + ``{min, max, bc}`` entries. + + :param inbound: Decoded inbound token payload. + :type inbound: dict + :returns: Ordered list of ``(range, backend_continuation)`` pairs. + :rtype: list[tuple[~azure.cosmos._routing.routing_range.Range, Optional[str]]] + """ + return [ + (_dict_to_range(entry), entry.get(_FIELD_BACKEND_CONTINUATION)) + for entry in inbound[_FIELD_CONTINUATIONS] + ] + + +def _build_scope_from_overlaps( + overlapping: List[dict], feedrange: routing_range.Range +) -> Tuple[List[dict], routing_range.Range]: + """Compute the smallest EPK ``Range`` that covers every one of the + overlapping physical partitions, and return both the original + overlaps and that combined range. + + Both the sync and async pagination paths call this directly after + awaiting / invoking ``routing_map_provider.get_overlapping_ranges`` + themselves, so the live lookup stays at the call site (sync vs. + async) and the pure combine logic is shared here. + + :param overlapping: Overlapping partition-range dictionaries. + :type overlapping: list[dict] + :param feedrange: Feed range used for error context. + :type feedrange: ~azure.cosmos._routing.routing_range.Range + :returns: Original overlaps and the combined range covering them. 
+ :rtype: tuple[list[dict], ~azure.cosmos._routing.routing_range.Range] + """ + if not overlapping: + raise RuntimeError( + "Routing map returned no overlapping ranges for feedrange " + "[{}, {}).".format(feedrange.min, feedrange.max) + ) + min_inclusive = overlapping[0]["minInclusive"] + max_exclusive = overlapping[0]["maxExclusive"] + for overlap_range in overlapping[1:]: + if overlap_range["minInclusive"] < min_inclusive: + min_inclusive = overlap_range["minInclusive"] + if overlap_range["maxExclusive"] > max_exclusive: + max_exclusive = overlap_range["maxExclusive"] + scope = routing_range.Range( + range_min=min_inclusive, + range_max=max_exclusive, + isMinInclusive=True, + isMaxInclusive=False, + ) + return overlapping, scope + + +def _derive_initial_feedranges( + feed_range_epk: routing_range.Range, overlapping: List[dict] +) -> List[routing_range.Range]: + """Given the caller's input feed_range and the partitions it + currently overlaps, return one sub-feedrange per partition (the + intersection of the partition's range and the input feed_range), + ordered by EPK ``min``. + + :param feed_range_epk: Requested feed range. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param overlapping: Overlapping partition-range dictionaries. + :type overlapping: list[dict] + :returns: Derived feed ranges ordered by ``min``. + :rtype: list[~azure.cosmos._routing.routing_range.Range] + """ + feedranges: List[routing_range.Range] = [] + for overlap_range in overlapping: + partition_range = routing_range.Range.PartitionKeyRangeToRange(overlap_range) + feedranges.append( + routing_range.Range( + range_min=max(partition_range.min, feed_range_epk.min), + range_max=min(partition_range.max, feed_range_epk.max), + isMinInclusive=True, + isMaxInclusive=False, + ) + ) + feedranges.sort(key=lambda feedrange_range: feedrange_range.min) + return feedranges + + +class _FeedRangePaginationState: + """Tracks where a feed_range query is up to between page calls. 
+ + Holds a single ordered queue of ``(sub-range, backend continuation)`` + pairs. The pagination loop: + + * peeks the queue head to learn the next sub-range to POST and + the backend continuation (if any) to send with it, + * updates the head's backend continuation when the backend + returns a non-null one, + * pops the head when the sub-range is drained, + * on a partition split, replaces the head with one entry per + child sub-range (each inheriting the parent's backend continuation). + + There is no separate "current vs. remaining" split. The head is + ``queue[0]`` and later entries are queued behind it. + + Split-child insertion is tail-based so existing queued ranges + remain ahead of newly discovered children. + + Not thread-safe. One instance is created per ``query_items`` call + and is mutated only by that call's pagination loop (sync or async) + — never shared across threads or concurrent tasks. + """ + + def __init__( + self, + queue: Iterable[Tuple[routing_range.Range, Optional[str]]], + page_size_hint: Optional[int], + ) -> None: + self.queue: Deque[Tuple[routing_range.Range, Optional[str]]] = deque(queue) + self.page_size_hint = page_size_hint + + @classmethod + def from_inbound( + cls, + inbound: dict, + page_size_hint: Optional[int], + ) -> "_FeedRangePaginationState": + """Build state from a decoded inbound token. + + :param inbound: Decoded inbound token payload. + :type inbound: dict + :param page_size_hint: Request page-size hint propagated to backend POSTs. + :type page_size_hint: Optional[int] + :returns: Pagination state initialized for resume. + :rtype: _FeedRangePaginationState + """ + return cls(_extract_resume_queue(inbound), page_size_hint) + + @classmethod + def from_derived_feedranges( + cls, + feedranges: Iterable[routing_range.Range], + page_size_hint: Optional[int], + ) -> "_FeedRangePaginationState": + """Build state from feedranges computed at startup (no backend + continuations yet — every entry starts with ``bc = None``). 
+ + :param feedranges: Derived feedranges ordered by ``min``. + :type feedranges: Iterable[~azure.cosmos._routing.routing_range.Range] + :param page_size_hint: Request page-size hint propagated to backend POSTs. + :type page_size_hint: Optional[int] + :returns: Pagination state initialized for first request. + :rtype: _FeedRangePaginationState + """ + return cls(((fr, None) for fr in feedranges), page_size_hint) + + @classmethod + def from_single_feedrange_with_continuation( + cls, + feedrange: routing_range.Range, + backend_continuation: Optional[str], + page_size_hint: Optional[int], + ) -> "_FeedRangePaginationState": + """Build state for one feedrange where a backend continuation + already exists. + + Used for legacy-token compatibility on full-PK queries: + we keep the decoder strict, then bridge a legacy continuation + string into the queue head's ``bc`` slot for the single target + range. + + :param feedrange: Single feedrange to seed. + :type feedrange: ~azure.cosmos._routing.routing_range.Range + :param backend_continuation: Existing backend continuation for the range. + :type backend_continuation: Optional[str] + :param page_size_hint: Request page-size hint propagated to backend POSTs. + :type page_size_hint: Optional[int] + :returns: Pagination state initialized with one queued entry. + :rtype: _FeedRangePaginationState + """ + return cls(((feedrange, backend_continuation),), page_size_hint) + + @property + def head_range(self) -> Optional[routing_range.Range]: + """The sub-range at the head of the queue (the one the next + backend POST will target), or ``None`` when the queue is drained. + """ + return self.queue[0][0] if self.queue else None + + @property + def head_bc(self) -> Optional[str]: + """Backend continuation paired with the head sub-range, or + ``None`` if the head has not been started yet (or has nothing + more to fetch). 
+ """ + return self.queue[0][1] if self.queue else None + + def can_issue_request(self) -> bool: + """Whether another backend POST can be issued for this page. + + :returns: ``True`` when the queue is non-empty. + :rtype: bool + """ + return bool(self.queue) + + def explode_on_multi_overlap(self, overlapping: List[dict]) -> bool: + """If the head sub-range now spans more than one physical + partition (Cosmos split it since the token was minted), + replace the head with one entry per child sub-range and carry + the parent backend continuation onto each child. + + Dequeue the parent and append child entries at the tail + (preserving child EPK order). Each child inherits the + parent ``bc`` so resume can continue after a split without + replaying the entire child feed range. + + :param overlapping: Routing overlaps for the head sub-range. + :type overlapping: list[dict] + :returns: ``True`` when the head was split into multiple children. + :rtype: bool + """ + if not self.queue or len(overlapping) <= 1: + return False + head_range, parent_bc = self.queue[0] + sub_feedranges = _derive_initial_feedranges(head_range, overlapping) + if not sub_feedranges: + return False + self.queue.popleft() + # Keep existing tail entries ahead of split children. + for sub in sub_feedranges: + self.queue.append((sub, parent_bc)) + return True + + def apply_post_result(self, items_returned: int, backend_continuation: Optional[str]) -> None: + """Apply one backend response to the queue. + + :param items_returned: Number of logical rows returned by this POST. + :type items_returned: int + :param backend_continuation: Backend continuation for the head + sub-range (``None`` when the head is drained). + :type backend_continuation: Optional[str] + """ + # Kept for call-site API symmetry and observability; page-size hints are + # no longer decremented between backend requests. 
+ _ = items_returned + if not self.queue: + return + head_range, _ = self.queue[0] + if backend_continuation is not None: + # Update head's bc in place; head sub-range itself is unchanged. + self.queue[0] = (head_range, backend_continuation) + else: + # Head sub-range fully drained; advance to next entry. + self.queue.popleft() + + def write_outbound_continuation( + self, + last_response_headers: MutableMapping[str, Any], + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + ) -> None: + """Set or clear the outbound continuation header from the queue. + + Empty queue means the pagination loop ran out of sub-ranges; the + header is removed and the caller's ``by_page`` loop terminates. + Otherwise the entire queue is serialized as a fresh v=1 envelope + via ``_build_outbound_token``. + + :param last_response_headers: Response headers to mutate. + :type last_response_headers: MutableMapping[str, Any] + :param resource_id: Collection resource ID. + :type resource_id: str + :param query: Query spec used for hashing. + :type query: str or dict + :param feed_range_epk: Original request feed range. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + """ + if not self.queue: + last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) + return + last_response_headers[http_constants.HttpHeaders.Continuation] = _build_outbound_token( + resource_id, + query, + feed_range_epk, + self.queue, + ) + + +def _write_query_outbound_continuation( + last_response_headers: MutableMapping[str, Any], + pagination_state: _FeedRangePaginationState, + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + is_full_pk_structured_scope: bool, + emit_legacy_for_single_partition: bool, +) -> None: + """Write outbound continuation for feed-range pagination. + + Full-PK queries always emit the legacy single-string continuation + so persisted bookmarks remain readable by older SDK versions. 
+ Feed-range/prefix queries emit legacy continuation when the caller's + input scope currently maps to a single physical partition; otherwise + they emit the structured envelope. + + :param last_response_headers: Response headers to mutate. + :type last_response_headers: MutableMapping[str, Any] + :param pagination_state: Current pagination state for this request. + :type pagination_state: _FeedRangePaginationState + :param resource_id: Collection resource ID. + :type resource_id: str + :param query: Query text/spec used for hash identity. + :type query: Any + :param feed_range_epk: Original request feed range. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param is_full_pk_structured_scope: Whether request scope is full-PK on structured path. + :type is_full_pk_structured_scope: bool + :param emit_legacy_for_single_partition: Whether non-full-PK scope currently maps to a + single physical partition and can safely emit legacy continuation. + :type emit_legacy_for_single_partition: bool + :returns: None. Mutates ``last_response_headers`` in place. + :rtype: None + """ + if is_full_pk_structured_scope or emit_legacy_for_single_partition: + legacy_outbound = pagination_state.head_bc + if legacy_outbound is None: + last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) + else: + last_response_headers[http_constants.HttpHeaders.Continuation] = legacy_outbound + return + pagination_state.write_outbound_continuation( + last_response_headers, + resource_id, + query, + feed_range_epk, + ) + + + +def _build_outbound_token( + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + entries: Iterable[Tuple[routing_range.Range, Optional[str]]], +) -> str: + """Build and base64-encode the outbound continuation token from a + queue of ``(range, backend_continuation)`` entries. + + Persists the queue as the wire-format ``c`` list in head-first order. + + :param resource_id: Collection resource ID. 
+ :type resource_id: str + :param query: Query spec used for hashing. + :type query: str or dict + :param feed_range_epk: Original feed range for the request. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param entries: Ordered ``(range, bc)`` pairs to serialize. + :type entries: Iterable[tuple[~azure.cosmos._routing.routing_range.Range, Optional[str]]] + :returns: Encoded continuation token. + :rtype: str + """ + payload = { + _FIELD_VERSION: _TOKEN_VERSION, + _FIELD_COLLECTION_RID: resource_id, + _FIELD_QUERY_HASH: _hash_query_spec(query), + _FIELD_FEEDRANGE_HASH: _hash_feed_range(feed_range_epk), + _FIELD_CONTINUATIONS: [ + {"min": r.min, "max": r.max, _FIELD_BACKEND_CONTINUATION: bc} + for r, bc in entries + ], + } + return _encode_token(payload) + + +# ----- Pagination-loop helpers shared by sync and async ------------------ +def _normalize_max_item_count(raw_max_item_count: Any) -> Optional[int]: + """Normalize the caller's ``maxItemCount`` to a positive page-size cap or + ``None`` (unbounded). + + Three rules, applied in order: + * ``None`` (caller did not set one) -> ``None`` (unbounded; backend + decides page size). + * Non-numeric values (e.g. a malformed string) -> ``None``. Raising + here would change the error surface for callers that previously + worked by accident; ``None`` keeps them working. + * Any value ``<= 0`` -> ``None``. A zero or negative cap would make + the pagination loop emit a continuation token without issuing any + backend POST, which can produce an empty-page-with-continuation + cycle on the caller side. + + :param raw_max_item_count: Raw ``maxItemCount`` value from options. 
+ :type raw_max_item_count: Any + """ + if raw_max_item_count is None: + return None + try: + normalized = int(raw_max_item_count) + except (TypeError, ValueError): + return None + if normalized <= 0: + return None + return normalized + + +def _count_page_items_from_partial_result( + partial_result: Optional[dict[str, Any]], + query: Any, +) -> int: + """Return how many logical items should consume the remaining page-item count. + + Aggregate partial rows are merge-input fragments, not final logical + rows, so they should not consume page items and force an early break. + + :param partial_result: One backend POST result. + :type partial_result: Optional[dict[str, Any]] + :param query: Query text or query spec dictionary. + :type query: Any + :returns: Number of items to subtract from the remaining page-item count. + :rtype: int + """ + if not partial_result: + return 0 + docs = partial_result.get("Documents") + if not isinstance(docs, list): + return 0 + if len(docs) != 1: + # Cosmos backend invariant: aggregate partial fragments are emitted as + # single-element arrays. Non-singleton arrays are treated as regular rows. + return len(docs) + + # Aggregate partials must be merged across overlaps before they count as rows. + if _classify_aggregate_partial(docs, query) != _AggregatePartialClassification.NONE: + return 0 + return 1 + + +def _update_no_progress_page_count( + current_no_progress_count: int, + page_items_returned: int, + previous_feedrange: Optional[routing_range.Range], + previous_backend_continuation: Optional[str], + head_feedrange: Optional[routing_range.Range], + head_backend_continuation: Optional[str], +) -> int: + """Track consecutive empty pages that still carry continuation. + + :param current_no_progress_count: Current consecutive no-progress page count. + :type current_no_progress_count: int + :param page_items_returned: Number of logical page items returned this iteration. 
+ :type page_items_returned: int + :param previous_feedrange: Feedrange before processing this response. + :type previous_feedrange: Optional[~azure.cosmos._routing.routing_range.Range] + :param previous_backend_continuation: Backend continuation before response. + :type previous_backend_continuation: Optional[str] + :param head_feedrange: Feedrange after processing this response. + :type head_feedrange: Optional[~azure.cosmos._routing.routing_range.Range] + :param head_backend_continuation: Backend continuation after response. + :type head_backend_continuation: Optional[str] + :returns: Updated consecutive no-progress page count. + :rtype: int + """ + def _range_bounds(rng: Optional[routing_range.Range]) -> Optional[Tuple[str, str]]: + if rng is None: + return None + return rng.min, rng.max + + if page_items_returned > 0: + return 0 + if head_backend_continuation is None: + return 0 + if _range_bounds(head_feedrange) != _range_bounds(previous_feedrange): + return 0 + if head_backend_continuation != previous_backend_continuation: + return 0 + + # No logical rows and no cursor/feedrange movement: caller made no progress. + return current_no_progress_count + 1 + + +def _increment_explode_iterations_or_raise(current_explode_iterations: int) -> int: + """Increment split re-resolution iteration count or raise on overflow. + + :param current_explode_iterations: Current explode-loop iteration count. + :type current_explode_iterations: int + :returns: Incremented explode-loop iteration count. + :rtype: int + """ + updated_iterations = current_explode_iterations + 1 + if updated_iterations > _MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS: + raise RuntimeError( + "Exceeded {} split re-resolution iterations while expanding overlapping " + "feed ranges. This indicates a stale/corrupted routing map response." 
+ .format(_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS) + ) + return updated_iterations + + +def _apply_feedrange_request_headers( + req_headers: MutableMapping[str, Any], + overlapping: List[dict], + partition_scope: routing_range.Range, + head_feedrange: routing_range.Range, + page_size_hint: Optional[int], + inbound_continuation: Optional[str], +) -> None: + """Populate ``req_headers`` for one backend POST against + ``head_feedrange`` and the partition currently serving it. + + Routes by ``PartitionKeyRangeID`` and only adds the EPK filter + headers when the current feed range is a strict sub-range of the + partition. Page size and continuation are explicitly set or + cleared so leftover state from the previous iteration cannot leak. + + :param req_headers: Mutable request headers to populate. + :type req_headers: MutableMapping[str, Any] + :param overlapping: Overlapping partition-range dictionaries. + :type overlapping: list[dict] + :param partition_scope: Union scope for overlapping partitions. + :type partition_scope: ~azure.cosmos._routing.routing_range.Range + :param head_feedrange: Feed range for the current backend request. + :type head_feedrange: ~azure.cosmos._routing.routing_range.Range + :param page_size_hint: Request page-size hint for the backend POST. + :type page_size_hint: Optional[int] + :param inbound_continuation: Continuation token for backend request. 
+ :type inbound_continuation: Optional[str] + """ + pkr_id = overlapping[0]["id"] + req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = pkr_id + + is_full_partition = ( + len(overlapping) == 1 + and head_feedrange.min == partition_scope.min + and head_feedrange.max == partition_scope.max + ) + if is_full_partition: + req_headers.pop(http_constants.HttpHeaders.StartEpkString, None) + req_headers.pop(http_constants.HttpHeaders.EndEpkString, None) + else: + req_headers[http_constants.HttpHeaders.StartEpkString] = head_feedrange.min + req_headers[http_constants.HttpHeaders.EndEpkString] = head_feedrange.max + req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" + + if page_size_hint is not None: + req_headers[http_constants.HttpHeaders.PageSize] = str(page_size_hint) + else: + req_headers.pop(http_constants.HttpHeaders.PageSize, None) + + if inbound_continuation is not None: + req_headers[http_constants.HttpHeaders.Continuation] = inbound_continuation + else: + req_headers.pop(http_constants.HttpHeaders.Continuation, None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index af83f50bae55..d581f6fc3808 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.14.6" +VERSION = "4.14.7" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 5c8a3354061d..3fd1541e21ca 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -23,10 +23,11 @@ """Document client class for the Azure Cosmos database service. 
""" +import logging import os from urllib.parse import urlparse import uuid -from typing import Callable, Any, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast from typing_extensions import TypedDict from urllib3.util.retry import Retry @@ -54,7 +55,21 @@ from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState from .._change_feed.feed_range_internal import FeedRangeInternalEpk -from .._routing import routing_range +from .._routing.feed_range_continuation import ( + _FeedRangePaginationState, + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES, + _apply_feedrange_request_headers, + _build_scope_from_overlaps, + _count_page_items_from_partial_result, + _decode_token, + _derive_initial_feedranges, + _increment_explode_iterations_or_raise, + _normalize_max_item_count, + _should_bridge_legacy_continuation, + _update_no_progress_page_count, + _validate_token_identity, + _write_query_outbound_continuation, +) from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants from .._cosmos_responses import CosmosDict, CosmosList @@ -80,6 +95,8 @@ from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from .._range_partition_resolver import RangePartitionResolver +_LOGGER = logging.getLogger(__name__) + @@ -2951,6 +2968,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, partition_key_range_id: Optional[str] = None, response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None, is_query_plan: bool = False, + response_headers_list: Optional[list[CaseInsensitiveDict]] = None, **kwargs: Any ) -> list[dict[str, Any]]: """Query for more than one Azure Cosmos resources. 
@@ -2969,6 +2987,9 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, :type response_hook: Callable[[Mapping[str, Any], dict[str, Any]], None] :param bool is_query_plan: Specifies if the call is to fetch query plan + :param response_headers_list: + Optional list to which per-request response headers will be appended. + :type response_headers_list: Optional[list[CaseInsensitiveDict]] :returns: A list of the queried resources. :rtype: list :raises SystemError: If the query compatibility mode is undefined. @@ -2991,6 +3012,31 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, # we need to set operation_state in kwargs as that's where it is looked at while sending the request kwargs.setdefault("timeout", timeout) + # The capture dict can arrive via two upstream paths: + # 1. The query execution context puts it into ``options`` (the + # common case for query pagination — see the async + # ``_QueryExecutionContextBase._fetch_items_helper_no_retries``). + # 2. ``routing_map_provider.get_routing_map`` puts it into + # ``kwargs`` for PK-range fetches. + # Honour both so checkpoint-on-failure works on every path. + internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( + "_internal_response_headers_capture", None + ) + if internal_headers_capture is None and isinstance(options, dict): + internal_headers_capture = options.pop( + "_internal_response_headers_capture", None + ) + + def _capture_internal_headers(headers: Mapping[str, Any]) -> None: + # `internal_headers_capture` is Optional[Dict]; checking it + # for None once inside this helper lets the type checker + # treat it as a plain Dict for the .clear()/.update() calls + # below, and keeps every call site to a single line. 
+ if internal_headers_capture is None: + return + internal_headers_capture.clear() + internal_headers_capture.update(headers) + if query: __GetBodiesFromQueryResult = result_fn else: @@ -3034,9 +3080,13 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: feed_options) request_params.headers = headers - result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(headers, result, last_response_headers) + result, get_response_headers = await self.__Get(path, request_params, headers, **kwargs) + self.last_response_headers = get_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(get_response_headers) + self._UpdateSessionIfRequired(headers, result, get_response_headers) + if response_headers_list is not None: + response_headers_list.append(get_response_headers.copy()) if response_hook: response_hook(self.last_response_headers, result) return __GetBodiesFromQueryResult(result) @@ -3066,88 +3116,339 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: await base.set_session_token_header_async(self, req_headers, path, request_params, options, partition_key_range_id) - # Check if the over lapping ranges can be populated + # Check if the overlapping ranges can be populated feed_range_epk = None + is_full_pk_structured_scope = False if "feed_range" in kwargs: feed_range = kwargs.pop("feed_range") feed_range_epk = FeedRangeInternalEpk.from_json(feed_range).get_normalized_range() elif options.get("partitionKey") is not None and container_property is not None: - # check if query has prefix partition key partition_key_value = options["partitionKey"] partition_key_obj = _build_partition_key_from_properties(container_property) if partition_key_obj._is_prefix_partition_key(partition_key_value): req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) 
partition_key_value = cast(_SequentialPartitionKeyType, partition_key_value) feed_range_epk = partition_key_obj._get_epk_range_for_prefix_partition_key(partition_key_value) + else: + req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) + # Full-PK returns a single-value inclusive range; normalize to + # [min, max) before routing-map overlap resolution. + feed_range_epk = partition_key_obj._get_epk_range_for_partition_key( + partition_key_value + ).to_normalized_range() + is_full_pk_structured_scope = True if feed_range_epk is not None: - over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feed_range_epk], - options) - results: dict[str, Any] = {} - # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over - # lapping physical partition. The EPK sub range will be one of four: - # 1) Will have a range min equal to the feed range EPK min, and a range max equal to the over lapping - # partition - # 2) Will have a range min equal to the over lapping partition range min, and a range max equal to the - # feed range EPK range max. - # 3) will match exactly with the current over lapping physical partition, so we just return the over lapping - # physical partition's partition key id. 
- # 4) Will equal the feed range EPK since it is a sub range of a single physical partition - for over_lapping_range in over_lapping_ranges: - single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range) - # Since the range min and max are all Upper Cased string Hex Values, - # we can compare the values lexicographically - EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min), - range_max=min(single_range.max, feed_range_epk.max), - isMinInclusive=True, isMaxInclusive=False) - if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max: - # The Epk Sub Range spans exactly one physical partition - # In this case we can route to the physical pk range id - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] + if id_ is None: + raise ValueError("resource_id is required for feed_range continuation.") + # The None-check above already narrows ``id_`` to ``str`` for the + # rest of this block. Bind it to a clearly-named local so the + # feed_range helpers below read as ``resource_id_str`` instead + # of the generic ``id_``. + resource_id_str: str = id_ + # Decode and validate inbound continuation for this request. + # ``None`` means start from the beginning of the requested + # feed range. + page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) + # Single shared copy of options for routing-map lookups in this call. + # ``get_overlapping_ranges`` does not mutate options; copying once + # avoids per-iteration ``dict(options)`` allocations. + routing_options = dict(options) + inbound_serialized_continuation = options.get("continuation") + inbound_token_payload = _decode_token(inbound_serialized_continuation) + legacy_bridge_in_use = False + # Cache for the input scope's single-partition classification. 
+ # We compute it at most once per __QueryFeed call so inbound + # bridge-detection, mid-page checkpoint, and end-of-page outbound + # writer all agree even if the PK range cache refreshes mid-call. + cached_is_single_partition: Optional[bool] = None + + async def _is_input_scope_single_partition() -> bool: + """Return True when the caller input range currently maps to one physical partition. + + Result is cached for the duration of this __QueryFeed call. + + :returns: True if the input scope maps to a single physical partition. + :rtype: bool + """ + nonlocal cached_is_single_partition + if cached_is_single_partition is None: + scope_overlaps = await self._routing_map_provider.get_overlapping_ranges( + resource_id_str, [feed_range_epk], routing_options + ) + cached_is_single_partition = ( + len(_derive_initial_feedranges(feed_range_epk, scope_overlaps)) == 1 + ) + return cached_is_single_partition + + if inbound_serialized_continuation and inbound_token_payload is None: + scope_is_single_partition = False + if not is_full_pk_structured_scope: + scope_is_single_partition = await _is_input_scope_single_partition() + if _should_bridge_legacy_continuation( + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + scope_is_single_partition, + ): + if is_full_pk_structured_scope: + _LOGGER.warning( + "Full-PK query continuation token is in legacy format; " + "bridging it into structured pagination state for resume." + ) + else: + _LOGGER.warning( + "Feed-range query continuation token is in legacy format and the input scope " + "currently maps to one physical partition; honoring legacy continuation for resume." 
+ ) + legacy_bridge_in_use = True else: - # The Epk Sub Range spans less than a single physical partition - # In this case we route to the physical partition and - # pass the epk sub range to the headers to filter within partition - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] - req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min - req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max - req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" - partial_result, last_response_headers = await self.__Post( - path, - request_params, + _LOGGER.warning( + "Feed-range query continuation token is not in the supported structured format; " + "restarting this feed_range query from the beginning." + ) + if inbound_token_payload is not None: + _validate_token_identity( + inbound_token_payload, + resource_id_str, query, - req_headers, - **kwargs + feed_range_epk, ) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) - - # Introducing a temporary complex function into a critical path to handle aggregated queries, - # during splits as a precaution falling back to the original logic if anything goes wrong - try: - results = base._merge_query_results(results, partial_result, query) - except Exception: # pylint: disable=broad-exception-caught - # If the new merge logic fails, fall back to the original logic. 
- if results: - results["Documents"].extend(partial_result["Documents"]) - else: - results = partial_result + pagination_state = _FeedRangePaginationState.from_inbound( + inbound_token_payload, page_size_hint + ) + elif legacy_bridge_in_use and inbound_serialized_continuation: + pagination_state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + feed_range_epk, + inbound_serialized_continuation, + page_size_hint, + ) + else: + first_overlaps = await self._routing_map_provider.get_overlapping_ranges( + resource_id_str, [feed_range_epk], routing_options + ) + all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps) + if not all_feedranges: + # The input feed_range overlaps no current physical + # partition. Fall through to the regular __Post path + # below so cross-partition gating can still surface + # (e.g. an empty ``partition_key=[]`` prefix with + # ``enableCrossPartitionQuery=False`` must raise + # ``BAD_REQUEST``). + pagination_state = None + else: + pagination_state = _FeedRangePaginationState.from_derived_feedranges( + all_feedranges, + page_size_hint, + ) + + if pagination_state is not None: + results: dict[str, Any] = {} + feedrange_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() + consecutive_no_progress_pages = 0 + + async def _checkpoint_and_reraise(error: Exception) -> None: + # Intentionally broad: stamp the latest resumable checkpoint + # for any mid-page failure, then re-raise the original error. 
+ self.last_response_headers = feedrange_response_headers + try: + single_partition_scope_for_outbound = ( + (not is_full_pk_structured_scope) and (await _is_input_scope_single_partition()) + ) + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + single_partition_scope_for_outbound, + ) + except Exception as continuation_write_error: # pylint: disable=broad-exception-caught + _LOGGER.warning( + "Failed to write continuation while handling query POST failure: %s", + continuation_write_error, + ) + raise error + + # NOTE: Keep this feed_range pagination loop in sync with + # ``azure/cosmos/_cosmos_client_connection.py::__QueryFeed``. + while pagination_state.can_issue_request(): + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + + # Wrap all mid-page work that can raise (routing lookups, + # scope build/explode, backend POST, and result merge) so + # we always stamp a resumable checkpoint into + # last_response_headers[Continuation] before re-raising. + # Post-result accounting below is pure local bookkeeping + # and is intentionally left outside this try. + try: + # Look up the live routing map for the current feedrange. + # Doing this every iteration is what makes the token + # split-safe. + overlapping = await self._routing_map_provider.get_overlapping_ranges( + resource_id_str, [head_feedrange], routing_options + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + # If routing returns multiple overlaps, the head sub-range now spans a split + # that occurred after the token was created. Re-slice and re-resolve until + # each head maps to one partition. See + # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. 
+ explode_iterations = 0 + while pagination_state.explode_on_multi_overlap(overlapping): + explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + overlapping = await self._routing_map_provider.get_overlapping_ranges( + resource_id_str, [head_feedrange], routing_options + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + head_feedrange = pagination_state.head_range + if head_feedrange is None: + continue + + # Populate request headers for this single backend POST. + # The shared helper handles partition routing (PKR id + + # optional EPK filter), page-size cap, and continuation + # set/clear so the same rules apply to sync and async. + _apply_feedrange_request_headers( + req_headers, + overlapping, + partition_scope, + head_feedrange, + pagination_state.page_size_hint, + pagination_state.head_bc, + ) + # Use the session token for this specific partition so we don't + # send a compound token covering all partitions. + await base.set_session_token_header_async( + self, req_headers, path, request_params, options, overlapping[0]["id"] + ) + + backend_query_result, backend_response_headers = await self.__Post( + path, + request_params, + query, + req_headers, + **kwargs + ) + feedrange_response_headers = backend_response_headers + self.last_response_headers = feedrange_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) + + # Merge results, falling back to a plain extend if the + # aggregating merge raises (it can on aggregated queries + # during splits). 
+ try: + results = base._merge_query_results(results, backend_query_result, query) + except ValueError as merge_error: + base._raise_query_merge_value_error(merge_error) + except (TypeError, KeyError) as merge_error: + _LOGGER.warning( + "Falling back to non-aggregate merge after aggregate merge failure: %s", + merge_error, + ) + results_docs = results.get("Documents") if results else None + partial_docs = backend_query_result.get("Documents") if backend_query_result else None + if isinstance(results_docs, list) and isinstance(partial_docs, list): + results_docs.extend(partial_docs) + elif backend_query_result: + results = backend_query_result + except Exception as mid_page_error: # pylint: disable=broad-exception-caught + await _checkpoint_and_reraise(mid_page_error) + + previous_feedrange = pagination_state.head_range + previous_backend_continuation = pagination_state.head_bc + page_items_returned = _count_page_items_from_partial_result(backend_query_result, query) + if response_headers_list is not None: + response_headers_list.append(backend_response_headers.copy()) + if response_hook: + response_hook(backend_response_headers, backend_query_result) + pagination_state.apply_post_result( + page_items_returned, + backend_response_headers.get(http_constants.HttpHeaders.Continuation), + ) + consecutive_no_progress_pages = _update_no_progress_page_count( + consecutive_no_progress_pages, + page_items_returned, + previous_feedrange, + previous_backend_continuation, + pagination_state.head_range, + pagination_state.head_bc, + ) + if ( + consecutive_no_progress_pages >= _MAX_CONSECUTIVE_NO_PROGRESS_PAGES + and consecutive_no_progress_pages % _MAX_CONSECUTIVE_NO_PROGRESS_PAGES == 0 + ): + # Warning-only: do not fail fast here. 
+ current_head = pagination_state.head_range + head_min = current_head.min if current_head else "" + head_max = current_head.max if current_head else "" + _LOGGER.warning( + "Feed-range query has returned 0 items for %s consecutive continuation pages " + "with the same continuation token and partition key range [%s, %s); continuing scan.", + consecutive_no_progress_pages, + head_min, + head_max, + ) + + # maxItemCount is a per-request hint. Return this SDK page + # after the first non-empty logical result instead of filling + # an exact target count by issuing extra backend requests. + if page_items_returned > 0: + break + + # Pagination loop is done — write the final outbound + # continuation (or clear the header if the queue is fully + # drained) so the caller's ``by_page`` loop terminates. + single_partition_scope_for_outbound = ( + (not is_full_pk_structured_scope) and (await _is_input_scope_single_partition()) + ) + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + single_partition_scope_for_outbound, + ) + # End feed_range pagination block. 
+ self.last_response_headers = feedrange_response_headers + # if the prefix partition query has results lets return it + if results: + if self.last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: + index_metrics_raw = self.last_response_headers[http_constants.HttpHeaders.IndexUtilization] + self.last_response_headers[http_constants.HttpHeaders.IndexUtilization] = ( + _utils.get_index_metrics_info(index_metrics_raw)) + return __GetBodiesFromQueryResult(results) + return [] - if response_hook: - response_hook(self.last_response_headers, partial_result) - # if the prefix partition query has results lets return it - if results: - return __GetBodiesFromQueryResult(results) + result, post_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = post_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(post_response_headers) - result, last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) - self.last_response_headers = last_response_headers # update session for request mutates data on server side - self._UpdateSessionIfRequired(req_headers, result, last_response_headers) + self._UpdateSessionIfRequired(req_headers, result, post_response_headers) # TODO: this part might become an issue since HTTP/2 can return read-only headers if self.last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = self.last_response_headers[INDEX_METRICS_HEADER] self.last_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) + if response_headers_list is not None: + response_headers_list.append(post_response_headers.copy()) if response_hook: response_hook(self.last_response_headers, result) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 9cb5578a8e78..b46e9a1a340f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -917,6 +917,7 @@ def query_items( # pylint:disable=docstring-missing-param # Get container property and init client container caches container_properties = self._get_properties_with_options(feed_options) + kwargs["container_properties"] = container_properties # Update 'feed_options' from 'kwargs' if utils.valid_key_value_exist(kwargs, "enable_cross_partition_query"): diff --git a/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py b/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py index c41fcd52567f..e6bc283de0e4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py +++ b/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py @@ -167,15 +167,29 @@ def test_replace_with_new_computed_properties(self): # Check that computed properties were properly sent to replaced container self.assertListEqual(new_computed_properties, replaced_collection.read()["computedProperties"]) - # Test 1: Test first computed property - queried_items = list( - replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', - partition_key="test")) + # Test 1: Test first computed property. Allow brief settle time after replace + # for the new computed-property index to be ready. 
+ queried_items = [] + for _ in range(10): + queried_items = list( + replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', + partition_key="test")) + if len(queried_items) == 3: + break + time.sleep(1) self.assertEqual(len(queried_items), 3) - # Test 1 Negative: Test if using non-existent computed property name returns nothing - queried_items = list( - replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', partition_key="test")) + # Test 1 Negative: Test if using non-existent computed property name returns nothing. + # After a container replace the backend may take a moment to drop the old + # computed property index ("cp_lower" from before the replace), so retry briefly. + queried_items = [] + for _ in range(10): + queried_items = list( + replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', + partition_key="test")) + if len(queried_items) == 0: + break + time.sleep(1) self.assertEqual(len(queried_items), 0) # Test 2: Test Second Computed Property @@ -184,8 +198,14 @@ def test_replace_with_new_computed_properties(self): self.assertEqual(len(queried_items), 2) # Test 2 Negative: Test Str length using old computed properties name - queried_items = list( - replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', partition_key="test")) + queried_items = [] + for _ in range(10): + queried_items = list( + replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', + partition_key="test")) + if len(queried_items) == 0: + break + time.sleep(1) self.assertEqual(len(queried_items), 0) self.created_db.delete_container(created_collection.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py b/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py index 3b4a35a97318..1f48b87558a0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio import unittest import uuid import pytest @@ -181,16 +182,30 @@ async def test_replace_with_new_computed_properties_async(self): container = await replaced_collection.read() assert new_computed_properties == container["computedProperties"] + # Give the backend a brief chance to complete computed-property indexing after replace. + queried_items = [] + for _ in range(10): + queried_items = [q async for q in + replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', + partition_key="test")] + if len(queried_items) == 3: + break + await asyncio.sleep(1) + # Test 1: Test first computed property - queried_items = [q async for q in - replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', - partition_key="test")] self.assertEqual(len(queried_items), 3) - # Test 1 Negative: Test if using non-existent computed property name returns nothing - queried_items = [q async for q in - replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', - partition_key="test")] + # Test 1 Negative: Test if using non-existent computed property name returns nothing. + # After a container replace the backend may take a moment to drop the old + # computed property index ("cp_lower" from before the replace), so retry briefly. 
+ queried_items = [] + for _ in range(10): + queried_items = [q async for q in + replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', + partition_key="test")] + if len(queried_items) == 0: + break + await asyncio.sleep(1) self.assertEqual(len(queried_items), 0) # Test 2: Test Second Computed Property @@ -200,9 +215,14 @@ async def test_replace_with_new_computed_properties_async(self): self.assertEqual(len(queried_items), 2) # Test 2 Negative: Test Str length using old computed properties name - queried_items = [q async for q in - replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', - partition_key="test")] + queried_items = [] + for _ in range(10): + queried_items = [q async for q in + replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', + partition_key="test")] + if len(queried_items) == 0: + break + await asyncio.sleep(1) self.assertEqual(len(queried_items), 0) await self.created_db.delete_container(created_collection.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py index a7bd88a35253..c2dc81fbf01a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py @@ -553,6 +553,115 @@ def test_partitioned_collection_prefix_partition_query_subpartition(self): self.assertTrue("Cross partition query is required but disabled" in error.message) + def test_partitioned_collection_full_partition_key_pagination_resume_subpartition(self): + created_db = self.databaseForTest + collection_id = 'test_partitioned_collection_full_partition_key_resume_MH ' + str(uuid.uuid4()) + created_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + partition_key = ['CA', 'Oxnard', '93033'] + total_items = 35 + for i in range(total_items): 
+ created_collection.create_item( + body={ + 'id': 'full-pk-doc-{0:03d}'.format(i), + 'state': partition_key[0], + 'city': partition_key[1], + 'zipcode': partition_key[2] + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=partition_key, + max_item_count=10 + ) + pager = query_iterable.by_page() + first_page = list(pager.next()) + self.assertGreater(len(first_page), 0) + token = pager.continuation_token + self.assertIsNotNone(token) + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resume_pager = query_iterable.by_page(token) + resumed_remaining_ids = [] + for page in resume_pager: + resumed_remaining_ids.extend(item['id'] for item in page) + + self.assertListEqual(expected_remaining_ids, resumed_remaining_ids) + + baseline_ids = [ + item['id'] for item in created_collection.query_items(query=query, partition_key=partition_key) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + self.assertListEqual(baseline_ids, fetched_ids) + + created_db.delete_container(created_collection.id) + + def test_partitioned_collection_prefix_partition_key_pagination_resume_subpartition(self): + created_db = self.databaseForTest + collection_id = 'test_partitioned_collection_prefix_partition_key_resume_MH ' + str(uuid.uuid4()) + created_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + for i in range(30): + created_collection.create_item( + body={ + 'id': 'ca-doc-{0:03d}'.format(i), + 'state': 'CA', + 'city': 'city-{0}'.format(i % 5), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + for i in range(5): + created_collection.create_item( + body={ + 'id': 'wa-doc-{0:03d}'.format(i), + 'state': 'WA', + 'city': 'city-{0}'.format(i), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + + query = 
'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=['CA'], + max_item_count=7 + ) + pager = query_iterable.by_page() + first_page = list(pager.next()) + self.assertGreater(len(first_page), 0) + token = pager.continuation_token + self.assertIsNotNone(token) + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resume_pager = query_iterable.by_page(token) + resumed_remaining_ids = [] + for page in resume_pager: + resumed_remaining_ids.extend(item['id'] for item in page) + + self.assertListEqual(expected_remaining_ids, resumed_remaining_ids) + + baseline_ids = [ + item['id'] for item in created_collection.query_items(query=query, partition_key=['CA']) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + self.assertListEqual(baseline_ids, fetched_ids) + + created_db.delete_container(created_collection.id) + def test_partition_key_range_overlap_subpartition(self): Id = 'id' MinInclusive = 'minInclusive' diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py index bcccc4735c63..759f3193d108 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py @@ -551,6 +551,115 @@ async def test_partitioned_collection_prefix_partition_query_subpartition_async( await created_db.delete_container(created_collection.id) + async def test_partitioned_collection_full_partition_key_pagination_resume_subpartition_async(self): + created_db = self.database_for_test + collection_id = 'test_partitioned_collection_full_partition_key_resume_MH_async ' + str(uuid.uuid4()) + created_collection = await created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + partition_key = 
['CA', 'Oxnard', '93033'] + total_items = 35 + for i in range(total_items): + await created_collection.create_item( + body={ + 'id': 'full-pk-doc-{0:03d}'.format(i), + 'state': partition_key[0], + 'city': partition_key[1], + 'zipcode': partition_key[2] + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=partition_key, + max_item_count=10 + ) + pager = query_iterable.by_page() + first_page = [item async for item in await pager.__anext__()] + assert len(first_page) > 0 + token = pager.continuation_token + assert token is not None + + expected_remaining_ids = [] + async for page in pager: + expected_remaining_ids.extend([item['id'] async for item in page]) + + resumed_remaining_ids = [] + resume_pager = query_iterable.by_page(token) + async for page in resume_pager: + resumed_remaining_ids.extend([item['id'] async for item in page]) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] async for item in created_collection.query_items(query=query, partition_key=partition_key) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + + await created_db.delete_container(created_collection.id) + + async def test_partitioned_collection_prefix_partition_key_pagination_resume_subpartition_async(self): + created_db = self.database_for_test + collection_id = 'test_partitioned_collection_prefix_partition_key_resume_MH_async ' + str(uuid.uuid4()) + created_collection = await created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + for i in range(30): + await created_collection.create_item( + body={ + 'id': 'ca-doc-{0:03d}'.format(i), + 'state': 'CA', + 'city': 'city-{0}'.format(i % 5), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + for i in range(5): + await created_collection.create_item( + 
body={ + 'id': 'wa-doc-{0:03d}'.format(i), + 'state': 'WA', + 'city': 'city-{0}'.format(i), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=['CA'], + max_item_count=7 + ) + pager = query_iterable.by_page() + first_page = [item async for item in await pager.__anext__()] + assert len(first_page) > 0 + token = pager.continuation_token + assert token is not None + + expected_remaining_ids = [] + async for page in pager: + expected_remaining_ids.extend([item['id'] async for item in page]) + + resumed_remaining_ids = [] + resume_pager = query_iterable.by_page(token) + async for page in resume_pager: + resumed_remaining_ids.extend([item['id'] async for item in page]) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] async for item in created_collection.query_items(query=query, partition_key=['CA']) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + + await created_db.delete_container(created_collection.id) + async def test_partition_key_range_subpartition_overlap(self): Id = 'id' MinInclusive = 'minInclusive' diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py new file mode 100644 index 000000000000..303a4f125556 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -0,0 +1,1582 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""Unit tests for ``azure.cosmos._routing.feed_range_continuation``. 
+ + +* ``TestTokenRoundTrip`` - ``_decode_token(_encode_token(p))`` returns + a structurally-equivalent dict; the wire form is valid base64 of + valid JSON; the JSON contains the five envelope keys (``v`` / ``cr`` / + ``qh`` / ``frh`` / ``c``) and a per-entry ``bc`` inside each + ``c[i]``. The wire format has NO privileged "current" slot — iteration + position is reconstructed in memory from ``c[0]``. +* ``TestVersionMismatchRejected`` - a token whose ``v`` field is set + but is not the SDK's current version raises ``ValueError`` with a + message naming both the offending and the supported version. +* ``TestIdentityFingerprintMismatch`` - a valid v=1 token whose ``cr`` + / ``qh`` / ``frh`` fingerprints disagree with the current request + raises ``ValueError`` with a message naming the failing field. + (Validation lives in the call sites; this test exercises the same + hash-based equality the call sites use.) +* ``TestExplodeOnMultiOverlap`` - when a saved feedrange resolves to + more than one physical partition on resume (the post-split case), + the call site must slice the feedrange into one sub-feedrange per + child before POSTing. These tests pin the geometry of that slice + without touching the network. 
+""" + +import base64 +import json + +import pytest + +from azure.cosmos import _base +from azure.cosmos import http_constants +from azure.cosmos._query_aggregate_utils import ( + _AggregatePartialClassification, + _classify_aggregate_partial, + _extract_outer_select_value_projection, + _get_select_value_aggregate_function, + _strip_sql_block_comments, +) +from azure.cosmos._routing import routing_range +from azure.cosmos._routing.feed_range_continuation import ( + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES, + _MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS, + _FeedRangePaginationState, + _apply_feedrange_request_headers, + _build_outbound_token, + _build_scope_from_overlaps, + _count_page_items_from_partial_result, + _decode_token, + _derive_initial_feedranges, + _encode_token, + _hash_feed_range, + _hash_query_spec, + _stable_hash_128, + _normalize_max_item_count, + _should_bridge_legacy_continuation, + _increment_explode_iterations_or_raise, + _write_query_outbound_continuation, + _update_no_progress_page_count, + _validate_token_identity, + _FIELD_BACKEND_CONTINUATION, + _FIELD_COLLECTION_RID, + _FIELD_CONTINUATIONS, + _FIELD_FEEDRANGE_HASH, + _FIELD_QUERY_HASH, + _FIELD_VERSION, + _TOKEN_VERSION, +) + + +# Fixed inputs reused across the round-trip / mismatch tests so each +# assertion compares against a known-good baseline. 
+# cspell:ignore AOXB BFFFFFFFFFFFFFFF BAAAAAAAAAA +_RID = "Yxs1AOXBSp4=" +_QUERY = {"query": "SELECT * FROM c WHERE c.x = @x", + "parameters": [{"name": "@x", "value": 7}]} +_FEED_RANGE = routing_range.Range( + range_min="3FFFFFFFFFFFFFFF", + range_max="BFFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, +) +_HEAD_FEEDRANGE = routing_range.Range( + range_min="3FFFFFFFFFFFFFFF", + range_max="7FFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, +) +_REMAINING_FEEDRANGE = routing_range.Range( + range_min="7FFFFFFFFFFFFFFF", + range_max="BFFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, +) +_BACKEND_CONT = "+RID:~Yxs1AOXBSp4BAAAAAAAAAA==#RT:1#TRC:5#ISV:2#IEO:65567" + + +def _mk_range(mn: str, mx: str) -> routing_range.Range: + return routing_range.Range(range_min=mn, range_max=mx, isMinInclusive=True, isMaxInclusive=False) + + +def _make_valid_token_payload() -> dict: + """Build a structurally-complete v=1 token payload over the fixtures. + + The wire format is a single ordered ``c`` list of + ``{min, max, bc}`` entries with no privileged "current" slot — + iteration position is reconstructed in memory from ``c[0]``. Each + entry carries its own ``bc``; the sequential loop only ever sets a + non-null ``bc`` for ``c[0]`` and leaves later entries' ``bc`` null. + """ + return { + _FIELD_VERSION: _TOKEN_VERSION, + _FIELD_COLLECTION_RID: _RID, + _FIELD_QUERY_HASH: _hash_query_spec(_QUERY), + _FIELD_FEEDRANGE_HASH: _hash_feed_range(_FEED_RANGE), + _FIELD_CONTINUATIONS: [ + { + "min": _HEAD_FEEDRANGE.min, + "max": _HEAD_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: _BACKEND_CONT, + }, + { + "min": _REMAINING_FEEDRANGE.min, + "max": _REMAINING_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: None, + }, + ], + } + + +class TestStableHash128MurmurRegression: + """Pin one MurmurHash3-128 output as a regression guard against + accidental algorithm drift. 
The other hash tests only assert + determinism / inequality between distinct inputs; this one pins + exact bytes for one fixed input.""" + + def test_known_input_produces_known_murmur_digest(self): + assert _stable_hash_128(b"feed_range_continuation_token_regression") == ( + "ce0130ea460342256309b38dfdbc9c50" + ) + + +# ---------------------------------------------------------------------- # +# Token round-trip +# ---------------------------------------------------------------------- # +class TestTokenRoundTrip: + """``_encode_token`` -> ``_decode_token`` is structurally lossless and + the wire form is base64-encoded JSON containing all five required + fields.""" + + def test_round_trip_preserves_all_fields(self): + payload = _make_valid_token_payload() + wire = _encode_token(payload) + decoded = _decode_token(wire) + assert decoded == payload + + def test_wire_form_is_base64_of_json(self): + payload = _make_valid_token_payload() + wire = _encode_token(payload) + # Wire form must be ASCII-safe and base64-decodable; the decoded + # bytes must be valid UTF-8 JSON; the JSON must be a dict. + raw = base64.b64decode(wire, validate=True) + as_json = json.loads(raw.decode("utf-8")) + assert isinstance(as_json, dict) + + def test_wire_form_contains_five_envelope_keys_and_per_entry_bc(self): + # The envelope itself has FIVE top-level keys (no top-level + # ``bc``, no privileged ``cf``/``rf`` split); ``bc`` lives + # inside each ``c[i]`` so a future non-sequential / parallel + # loop can record one backend continuation per sub-range + # without a wire-format bump. 
+ payload = _make_valid_token_payload() + wire = _encode_token(payload) + decoded_json = json.loads(base64.b64decode(wire, validate=True).decode("utf-8")) + envelope_required = { + _FIELD_VERSION, + _FIELD_COLLECTION_RID, + _FIELD_QUERY_HASH, + _FIELD_FEEDRANGE_HASH, + _FIELD_CONTINUATIONS, + } + assert envelope_required == set(decoded_json.keys()) + assert _FIELD_BACKEND_CONTINUATION not in decoded_json, ( + "envelope must NOT carry a top-level 'bc'; bc is per-entry" + ) + assert "cf" not in decoded_json, ( + "envelope must NOT carry a privileged 'cf' slot" + ) + assert "rf" not in decoded_json, ( + "envelope must NOT carry a 'rf' tail; sub-ranges live in a single 'c' list" + ) + assert isinstance(decoded_json[_FIELD_CONTINUATIONS], list) + assert len(decoded_json[_FIELD_CONTINUATIONS]) >= 1 + for entry in decoded_json[_FIELD_CONTINUATIONS]: + assert _FIELD_BACKEND_CONTINUATION in entry + + def test_build_outbound_token_emits_valid_token(self): + wire = _build_outbound_token( + resource_id=_RID, + query=_QUERY, + feed_range_epk=_FEED_RANGE, + entries=[ + (_HEAD_FEEDRANGE, _BACKEND_CONT), + (_REMAINING_FEEDRANGE, None), + ], + ) + decoded = _decode_token(wire) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + assert decoded[_FIELD_COLLECTION_RID] == _RID + assert decoded[_FIELD_QUERY_HASH] == _hash_query_spec(_QUERY) + assert decoded[_FIELD_FEEDRANGE_HASH] == _hash_feed_range(_FEED_RANGE) + # Head of the ``c`` list == the in-flight slice; tail == queued. 
+ assert decoded[_FIELD_CONTINUATIONS] == [ + { + "min": _HEAD_FEEDRANGE.min, + "max": _HEAD_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: _BACKEND_CONT, + }, + { + "min": _REMAINING_FEEDRANGE.min, + "max": _REMAINING_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: None, + }, + ] + assert _FIELD_BACKEND_CONTINUATION not in decoded + assert "cf" not in decoded + assert "rf" not in decoded + + + def test_per_entry_backend_continuations_coexist(self): + # The shape that motivated the flat ``c`` list: a future + # non-sequential / parallel-fetch loop emits a token where + # multiple entries each carry their own non-null backend + # continuation. Today's sequential loop never produces this + # state, but the wire shape must already support it so that + # when parallel fetch lands no version bump is needed. + wire = _build_outbound_token( + resource_id=_RID, + query=_QUERY, + feed_range_epk=_FEED_RANGE, + entries=[ + (_HEAD_FEEDRANGE, "B-cont-5"), + (_REMAINING_FEEDRANGE, "A-cont-5"), + ], + ) + decoded = _decode_token(wire) + assert decoded is not None + entries = decoded[_FIELD_CONTINUATIONS] + assert entries[0][_FIELD_BACKEND_CONTINUATION] == "B-cont-5" + assert entries[1][_FIELD_BACKEND_CONTINUATION] == "A-cont-5" + + def test_none_and_empty_inputs_decode_to_none(self): + # An empty / missing continuation must NOT raise - the call site + # treats it as "first call, derive feedranges from the routing map". 
+ assert _decode_token(None) is None + assert _decode_token("") is None + + +# ---------------------------------------------------------------------- # +# Version-mismatch rejection +# ---------------------------------------------------------------------- # +class TestVersionMismatchRejected: + """A token that decodes as our shape but with a non-current ``v`` + raises ``ValueError`` rather than being silently misinterpreted.""" + + def test_future_version_raises(self): + payload = _make_valid_token_payload() + payload[_FIELD_VERSION] = 999 + wire = _encode_token(payload) + with pytest.raises(ValueError) as excinfo: + _decode_token(wire) + msg = str(excinfo.value) + assert "999" in msg + assert str(_TOKEN_VERSION) in msg + + def test_zero_version_raises(self): + payload = _make_valid_token_payload() + payload[_FIELD_VERSION] = 0 + wire = _encode_token(payload) + with pytest.raises(ValueError): + _decode_token(wire) + + def test_missing_continuations_list_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_CONTINUATIONS] + wire = _encode_token(payload) + with pytest.raises(ValueError) as excinfo: + _decode_token(wire) + assert _FIELD_CONTINUATIONS in str(excinfo.value) + + def test_empty_continuations_list_raises(self): + # An empty ``c`` list cannot legitimately appear on the wire: + # the producer clears the outbound continuation header in the + # drained case rather than emitting a token with no entries. + payload = _make_valid_token_payload() + payload[_FIELD_CONTINUATIONS] = [] + wire = _encode_token(payload) + with pytest.raises(ValueError) as excinfo: + _decode_token(wire) + assert _FIELD_CONTINUATIONS in str(excinfo.value) + + +class TestMalformedV1TokenRejected: + """Malformed v1 tokens should raise ValueError at decode time. + + This prevents downstream call-sites from seeing KeyError when indexing + required identity and feedrange fields. 
+ """ + + def test_missing_cr_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_COLLECTION_RID] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "cr" in str(excinfo.value) + + def test_missing_qh_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_QUERY_HASH] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "qh" in str(excinfo.value) + + def test_missing_frh_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_FEEDRANGE_HASH] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "frh" in str(excinfo.value) + + def test_missing_head_min_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_CONTINUATIONS][0]["min"] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "{}[0].min".format(_FIELD_CONTINUATIONS) in str(excinfo.value) + + def test_malformed_tail_entry_raises(self): + payload = _make_valid_token_payload() + payload[_FIELD_CONTINUATIONS] = [payload[_FIELD_CONTINUATIONS][0], "not-an-object"] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "{}[1]".format(_FIELD_CONTINUATIONS) in str(excinfo.value) + + def test_non_string_backend_continuation_raises(self): + # ``bc`` lives inside each ``c[i]``; a non-string value must raise. + payload = _make_valid_token_payload() + payload[_FIELD_CONTINUATIONS][0][_FIELD_BACKEND_CONTINUATION] = 123 + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "bc" in str(excinfo.value) + + def test_envelope_level_backend_continuation_is_rejected(self): + # An older shape carried ``bc`` at the envelope root. The + # current shape moves ``bc`` inside each ``c[i]`` entry; a + # top-level ``bc`` must be rejected so a token from any earlier + # build fails loudly instead of being silently dropped. 
+ payload = _make_valid_token_payload() + payload[_FIELD_BACKEND_CONTINUATION] = "envelope-level-bc" + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "bc" in str(excinfo.value) + + + def test_missing_per_entry_backend_continuation_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_CONTINUATIONS][0][_FIELD_BACKEND_CONTINUATION] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "bc" in str(excinfo.value) + + +# ---------------------------------------------------------------------- # +# Identity-fingerprint mismatch rejection +# ---------------------------------------------------------------------- # +class TestIdentityFingerprintMismatch: + """A valid v=1 token replayed against a different collection / query / + feed_range produces a fingerprint mismatch the call site rejects. + + The hash helpers are deterministic and the call-site validators in + ``__QueryFeed`` compare ``inbound[_FIELD_*]`` to ``_hash_*(current)`` + and raise ``ValueError`` on mismatch.""" + + def test_collection_rid_mismatch_detected(self): + payload = _make_valid_token_payload() + decoded = _decode_token(_encode_token(payload)) + assert decoded is not None + # Same call-site shape: compare cr to current resource_id. 
+ assert decoded[_FIELD_COLLECTION_RID] == _RID + assert decoded[_FIELD_COLLECTION_RID] != "different-collection-rid==" + + def test_query_text_change_changes_hash(self): + original = _hash_query_spec(_QUERY) + modified = _hash_query_spec({ + "query": "SELECT * FROM c WHERE c.x = @x AND c.y = 1", + "parameters": _QUERY["parameters"], + }) + assert original != modified, ( + "query-text change must produce a different hash so the " + "call site can reject the resume") + + def test_query_parameter_value_change_changes_hash(self): + original = _hash_query_spec(_QUERY) + modified = _hash_query_spec({ + "query": _QUERY["query"], + "parameters": [{"name": "@x", "value": 8}], + }) + assert original != modified + + def test_query_parameter_name_change_changes_hash(self): + original = _hash_query_spec(_QUERY) + modified = _hash_query_spec({ + "query": _QUERY["query"], + "parameters": [{"name": "@y", "value": 7}], + }) + assert original != modified + + def test_query_string_form_hashes_consistently(self): + # When the caller passes a plain string (no parameters) the hash + # must still be deterministic and stable. + h1 = _hash_query_spec("SELECT * FROM c") + h2 = _hash_query_spec("SELECT * FROM c") + h3 = _hash_query_spec("SELECT VALUE c FROM c") + assert h1 == h2 + assert h1 != h3 + + def test_feed_range_change_changes_hash(self): + original = _hash_feed_range(_FEED_RANGE) + wider = _hash_feed_range(routing_range.Range( + range_min=_FEED_RANGE.min, + range_max="FFFFFFFFFFFFFFFF", + isMinInclusive=True, isMaxInclusive=False, + )) + narrower = _hash_feed_range(routing_range.Range( + range_min=_FEED_RANGE.min, + range_max="9FFFFFFFFFFFFFFF", + isMinInclusive=True, isMaxInclusive=False, + )) + assert original != wider + assert original != narrower + assert wider != narrower + + def test_feed_range_hash_is_stable(self): + # Same feed_range -> same hash on every call (no random state). 
+ h1 = _hash_feed_range(_FEED_RANGE) + h2 = _hash_feed_range(_FEED_RANGE) + assert h1 == h2 + + def test_feed_range_inclusivity_normalization_yields_same_hash(self): + # Hashing is based on the logical normalized EPK interval, so + # equivalent ranges with different bound inclusivity spellings + # must produce the same hash. + non_normalized = routing_range.Range( + range_min="0000000000000000", + range_max="7FFFFFFFFFFFFFFF", + isMinInclusive=False, + isMaxInclusive=True, + ) + normalized_image = non_normalized.to_normalized_range() + # Sanity: normalization actually changed something so the test + # is exercising the equivalence, not a no-op. + assert (normalized_image.isMinInclusive, normalized_image.isMaxInclusive) == (True, False) + + # The two forms must hash equal because they describe the same + # logical [min, max) interval after normalization. + assert _hash_feed_range(non_normalized) == _hash_feed_range(normalized_image) + + # And feeding the function the already-normalized form must + # yield the same digest a second time (idempotent). + assert _hash_feed_range(non_normalized) == _hash_feed_range( + routing_range.Range( + range_min=normalized_image.min, + range_max=normalized_image.max, + isMinInclusive=True, + isMaxInclusive=False, + ) + ) + + def test_call_site_replay_against_other_collection_raises(self): + """Drive the production validator (``_validate_token_identity``) + with a token built for ``_RID`` and resume against a different + collection rid. 
It must raise a ``ValueError`` whose message names
+        the failing field, matching the call-site contract in
+        ``__QueryFeed`` (sync and async)."""
+        payload = _make_valid_token_payload()
+        inbound = _decode_token(_encode_token(payload))
+        assert inbound is not None
+
+        with pytest.raises(ValueError) as excinfo:
+            _validate_token_identity(
+                inbound,
+                resource_id="different-collection-rid==",
+                query=_QUERY,
+                feed_range_epk=_FEED_RANGE,
+            )
+        assert "collection" in str(excinfo.value).lower()
+
+    def test_call_site_replay_with_different_query_raises(self):
+        payload = _make_valid_token_payload()
+        inbound = _decode_token(_encode_token(payload))
+        assert inbound is not None
+
+        with pytest.raises(ValueError) as excinfo:
+            _validate_token_identity(
+                inbound,
+                resource_id=_RID,
+                query={"query": "SELECT c.id FROM c", "parameters": []},
+                feed_range_epk=_FEED_RANGE,
+            )
+        assert "query" in str(excinfo.value).lower()
+
+    def test_call_site_replay_with_different_feed_range_raises(self):
+        payload = _make_valid_token_payload()
+        inbound = _decode_token(_encode_token(payload))
+        assert inbound is not None
+
+        other_feed_range = routing_range.Range(
+            range_min="0000000000000000",
+            range_max="3FFFFFFFFFFFFFFF",
+            isMinInclusive=True,
+            isMaxInclusive=False,
+        )
+        with pytest.raises(ValueError) as excinfo:
+            _validate_token_identity(
+                inbound,
+                resource_id=_RID,
+                query=_QUERY,
+                feed_range_epk=other_feed_range,
+            )
+        assert "feed_range" in str(excinfo.value).lower()
+
+
+
+# ---------------------------------------------------------------------- #
+# Explode-on-multi-overlap - post-split fan-out unit contract
+# ---------------------------------------------------------------------- #
+class TestExplodeOnMultiOverlap:
+    """Post-split fan-out contract for the resume path.
+
+    Setup: a saved feedrange ``[A, C)`` lived inside one physical
+    partition on the day the token was emitted.
By the time the token + is resumed, that partition has split at ``B`` into two children + ``X1 = [A, B)`` and ``X2 = [B, C)``. Re-resolving the saved + feedrange against the live routing map now returns two overlaps, + not one. + + If the call site just POSTed once against ``X1`` with + ``EndEpkString = C``, every row that physically lives on ``X2`` + would be silently dropped - the backend's EPK filter only returns + rows on the partition the request was routed to. + + The contract in ``__QueryFeed`` (sync and async): when + ``len(overlapping) > 1``, hand the saved feedrange and the new + children to ``_derive_initial_feedranges`` to get one sub-feedrange + per child (each the intersection of the child's range with the + saved feedrange). The first becomes the new ``head_feedrange``, + the rest are prepended to ``pending_feedranges``, and the + parent's backend continuation is dropped (it referenced the old + partition's id). The next loop iteration sees a single overlap and + falls through to the normal single-partition POST. + + The tests below pin four properties of that slice: + + * one sub-feedrange per child, + * sub-feedranges cover the saved feedrange end-to-end with no + gap and no overlap, + * order is by EPK ``min`` regardless of input order, + * each sub-feedrange resolves to exactly one child on the next + loop iteration (so the slice branch does not re-fire).""" + + @staticmethod + def _pkr(pkr_id: str, mn: str, mx: str) -> dict: + # The minimal partition_key_range dict shape that the routing + # map provider hands back. _build_scope_from_overlaps and + # _derive_initial_feedranges both consume this shape directly. + return {"id": pkr_id, "minInclusive": mn, "maxExclusive": mx} + + def test_two_child_split_slices_into_two_sub_feedranges(self): + # Saved feedrange covers the whole of the pre-split parent. + # After the split it resolves to two children X1 and X2; the + # slice must hand back one sub-feedrange per child. 
+ saved_feedrange = routing_range.Range( + range_min="05C1D9D533F364", range_max="05C1D9F59FF5A0", + isMinInclusive=True, isMaxInclusive=False) + x1 = self._pkr("X1", "05C1D9D533F364", "05C1D9E4000000") + x2 = self._pkr("X2", "05C1D9E4000000", "05C1D9F59FF5A0") + + sub_feedranges = _derive_initial_feedranges(saved_feedrange, [x1, x2]) + + assert len(sub_feedranges) == 2, ( + "Day-N resolution returned 2 children but the slice " + "produced {} sub-feedranges".format(len(sub_feedranges))) + assert (sub_feedranges[0].min, sub_feedranges[0].max) == ( + "05C1D9D533F364", "05C1D9E4000000"), ( + "first sub-feedrange should be the X1 slice of the saved feedrange") + assert (sub_feedranges[1].min, sub_feedranges[1].max) == ( + "05C1D9E4000000", "05C1D9F59FF5A0"), ( + "second sub-feedrange should be the X2 slice of the saved feedrange") + + def test_sub_feedranges_partition_parent_exactly(self): + # A wider variant: the saved feedrange sits inside an even + # bigger old partition that has since split into THREE children. + # The slice must still cover the saved feedrange end-to-end - + # every row that was returnable under the old layout must still + # be reachable under the new one, no gap (= missing rows) and + # no overlap (= duplicates). + saved_feedrange = routing_range.Range( + range_min="20", range_max="E0", + isMinInclusive=True, isMaxInclusive=False) + children = [ + self._pkr("c1", "00", "55"), + self._pkr("c2", "55", "AA"), + self._pkr("c3", "AA", "FF"), + ] + sub_feedranges = _derive_initial_feedranges(saved_feedrange, children) + + bounds = [(s.min, s.max) for s in sub_feedranges] + assert bounds == [("20", "55"), ("55", "AA"), ("AA", "E0")], ( + "sub-feedranges must be the intersections of each child with " + "the saved feedrange; got {}".format(bounds)) + # First sub-feedrange starts where the saved feedrange starts; + # last one ends where it ends. Anything else loses rows at the + # edges. 
+ assert bounds[0][0] == saved_feedrange.min + assert bounds[-1][1] == saved_feedrange.max + # And the sub-feedranges butt up against each other with no gap + # and no overlap. + for i in range(len(bounds) - 1): + assert bounds[i][1] == bounds[i + 1][0], ( + "sub-feedranges {} and {} have a gap or overlap at the " + "boundary; rows in between would be missed or " + "duplicated".format(bounds[i], bounds[i + 1])) + + def test_sub_feedranges_are_deterministically_ordered(self): + # The routing map provider doesn't promise any particular order + # when it returns the children. The call site prepends the + # leftover sub-feedranges to pending_feedranges, so if the + # order depended on what the provider happened to return, two + # different SDK processes resuming the same token could end up + # walking the children in different orders - and emit different + # outbound tokens halfway through. Pin EPK-min order regardless + # of input order. + saved_feedrange = routing_range.Range( + range_min="05C1D9D533F364", range_max="05C1D9F59FF5A0", + isMinInclusive=True, isMaxInclusive=False) + x1 = self._pkr("X1", "05C1D9D533F364", "05C1D9E4000000") + x2 = self._pkr("X2", "05C1D9E4000000", "05C1D9F59FF5A0") + + forward = _derive_initial_feedranges(saved_feedrange, [x1, x2]) + reverse = _derive_initial_feedranges(saved_feedrange, [x2, x1]) + + assert ([(r.min, r.max) for r in forward] + == [(r.min, r.max) for r in reverse]), ( + "slice result depended on input child order; resuming the " + "same token from two processes would diverge") + + def test_each_sub_feedrange_resolves_to_a_single_child(self): + # Why the slice is correct end-to-end: after slicing, the + # NEW head_feedrange is X1's slice, and the next iteration + # of the __QueryFeed loop re-resolves it against the routing + # map. That re-resolution must come back with exactly one + # overlap (X1) - otherwise we'd loop into the slice branch a + # second time. Same for X2 once X1 is drained. 
This pins the + # invariant that each sub-feedrange routes cleanly to one + # partition, which is what lets the rest of the loop fall + # through to the single-partition POST. + saved_feedrange = routing_range.Range( + range_min="05C1D9D533F364", range_max="05C1D9F59FF5A0", + isMinInclusive=True, isMaxInclusive=False) + children = [ + self._pkr("X1", "05C1D9D533F364", "05C1D9E4000000"), + self._pkr("X2", "05C1D9E4000000", "05C1D9F59FF5A0"), + ] + sub_feedranges = _derive_initial_feedranges(saved_feedrange, children) + + for sb in sub_feedranges: + owning_child = next(c for c in children + if c["minInclusive"] <= sb.min < c["maxExclusive"]) + overlaps, scope = _build_scope_from_overlaps([owning_child], sb) + assert len(overlaps) == 1, ( + "sub-feedrange [{}, {}) re-resolved to {} overlaps; the next " + "loop iteration would slice again".format( + sb.min, sb.max, len(overlaps))) + assert (scope.min, scope.max) == (sb.min, sb.max), ( + "sub-feedrange [{}, {}) routes to a partition whose scope is " + "[{}, {}); the EPK filter would over-fetch or " + "under-fetch".format(sb.min, sb.max, scope.min, scope.max)) + + def test_no_split_single_overlap_is_not_sliced(self): + # The "feed range fits inside one child" path: a feedrange that + # sits entirely inside a single child still resolves to one + # overlap. The slice branch is gated by `if len(overlapping) > 1`, + # so the call site goes straight to the single-partition POST. + # This is the negative control - verifying nothing in our slice + # helpers fires spuriously for the common safe case. 
+ feed_range = routing_range.Range( + range_min="40", range_max="60", + isMinInclusive=True, isMaxInclusive=False) + overlaps, scope = _build_scope_from_overlaps( + [self._pkr("c1", "00", "80")], feed_range) + assert len(overlaps) == 1 + assert (scope.min, scope.max) == ("00", "80"), ( + "single-overlap re-resolution returned the wrong " + "partition scope") + + def test_three_child_split_slices_into_three(self): + # The 1->2 split is the common case, but nothing in the design + # caps it there - over enough wall-clock time, X1 and X2 can + # themselves split, and a saved feedrange from before all of + # those splits will resolve to 3+ overlaps. Pin that the slice + # handles N children the same way it handles 2: one + # sub-feedrange per child, in EPK order, covering the saved + # feedrange. + saved_feedrange = routing_range.Range( + range_min="00", range_max="FF", + isMinInclusive=True, isMaxInclusive=False) + children = [ + self._pkr("c1", "00", "55"), + self._pkr("c2", "55", "AA"), + self._pkr("c3", "AA", "FF"), + ] + sub_feedranges = _derive_initial_feedranges(saved_feedrange, children) + + assert len(sub_feedranges) == 3 + assert [(s.min, s.max) for s in sub_feedranges] == [ + ("00", "55"), ("55", "AA"), ("AA", "FF"), + ] + + +# ---------------------------------------------------------------------- # +# max_item_count normalization +# ---------------------------------------------------------------------- # +class TestNormalizeMaxItemCount: + """``_normalize_max_item_count`` collapses unset / non-numeric / non-positive + values to ``None`` (unbounded) and passes positive ints through unchanged. + + The pagination loop interprets ``None`` as "no client-side cap" and any + positive int as the per-page item limit. 
A zero or negative cap would make + the loop exit before issuing any POST while still emitting a + continuation token, leaving the caller in an empty-page-with-continuation + cycle - so those cases must be normalized to ``None``.""" + + def test_none_passes_through(self): + assert _normalize_max_item_count(None) is None + + def test_positive_int_passes_through(self): + assert _normalize_max_item_count(5) == 5 + + def test_positive_str_is_parsed(self): + assert _normalize_max_item_count("25") == 25 + + def test_zero_is_treated_as_unbounded(self): + assert _normalize_max_item_count(0) is None + + def test_negative_is_treated_as_unbounded(self): + assert _normalize_max_item_count(-1) is None + + def test_non_numeric_is_treated_as_unbounded(self): + assert _normalize_max_item_count("not-a-number") is None + + def test_object_is_treated_as_unbounded(self): + assert _normalize_max_item_count(object()) is None + + +# ---------------------------------------------------------------------- # +# Request-header shaping +# ---------------------------------------------------------------------- # +class TestApplyFeedrangeRequestHeaders: + """``_apply_feedrange_request_headers`` sets and clears routing/page/token + headers correctly for both full-partition and sub-range requests.""" + + @pytest.mark.parametrize( + "head_feedrange,expect_epk_headers", + [ + # full-partition request -> EPK headers must be cleared + (_mk_range("10", "20"), False), + # strict sub-range request -> EPK headers must be stamped + (_mk_range("12", "18"), True), + ], + ) + def test_epk_headers_match_full_vs_subrange(self, head_feedrange, expect_epk_headers): + req_headers = { + # pre-populate with stale values to prove clear behavior + http_constants.HttpHeaders.StartEpkString: "stale-start", + http_constants.HttpHeaders.EndEpkString: "stale-end", + } + overlapping = [{"id": "7", "minInclusive": "10", "maxExclusive": "20"}] + partition_scope = _mk_range("10", "20") + + _apply_feedrange_request_headers( + 
req_headers=req_headers, + overlapping=overlapping, + partition_scope=partition_scope, + head_feedrange=head_feedrange, + page_size_hint=None, + inbound_continuation=None, + ) + + assert req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] == "7" + assert req_headers[http_constants.HttpHeaders.ReadFeedKeyType] == "EffectivePartitionKeyRange" + if expect_epk_headers: + assert req_headers[http_constants.HttpHeaders.StartEpkString] == head_feedrange.min + assert req_headers[http_constants.HttpHeaders.EndEpkString] == head_feedrange.max + else: + assert http_constants.HttpHeaders.StartEpkString not in req_headers + assert http_constants.HttpHeaders.EndEpkString not in req_headers + + @pytest.mark.parametrize( + "page_size_hint,inbound_continuation,expect_page_size,expect_continuation", + [ + (5, "abc", True, True), + (None, "abc", False, True), + (5, None, True, False), + (None, None, False, False), + ], + ) + def test_page_size_and_continuation_are_set_or_cleared( + self, + page_size_hint, + inbound_continuation, + expect_page_size, + expect_continuation, + ): + req_headers = { + # pre-populate stale values; helper should clear when args are None + http_constants.HttpHeaders.PageSize: "999", + http_constants.HttpHeaders.Continuation: "stale-cont", + } + overlapping = [{"id": "9", "minInclusive": "30", "maxExclusive": "40"}] + partition_scope = _mk_range("30", "40") + head_feedrange = _mk_range("30", "40") + + _apply_feedrange_request_headers( + req_headers=req_headers, + overlapping=overlapping, + partition_scope=partition_scope, + head_feedrange=head_feedrange, + page_size_hint=page_size_hint, + inbound_continuation=inbound_continuation, + ) + + if expect_page_size: + assert req_headers[http_constants.HttpHeaders.PageSize] == str(page_size_hint) + else: + assert http_constants.HttpHeaders.PageSize not in req_headers + + if expect_continuation: + assert req_headers[http_constants.HttpHeaders.Continuation] == inbound_continuation + else: + assert 
http_constants.HttpHeaders.Continuation not in req_headers + + +class TestBudgetCounting: + """Page-item counting treats aggregate partial rows as merge fragments.""" + + def test_standard_documents_consume_page_item_limit(self): + partial_result = {"Documents": [{"id": "1"}, {"id": "2"}]} + assert _count_page_items_from_partial_result(partial_result, "SELECT * FROM c") == 2 + + def test_multi_element_documents_in_aggregate_context_consume_page_item_limit(self): + partial_result = {"Documents": [{"_aggregate": {"count": 7}}, {"_aggregate": {"count": 3}}]} + assert _count_page_items_from_partial_result(partial_result, "SELECT COUNT(1) FROM c") == 2 + + def test_object_aggregate_partial_does_not_consume_page_item_limit(self): + partial_result = {"Documents": [{"_aggregate": {"count": 7}}]} + assert _count_page_items_from_partial_result(partial_result, "SELECT COUNT(1) FROM c") == 0 + + def test_value_aggregate_partial_does_not_consume_page_item_limit(self): + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, "SELECT VALUE COUNT(1) FROM c") == 0 + + def test_value_non_aggregate_numeric_row_consumes_page_item_limit(self): + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, "SELECT VALUE c.value FROM c") == 1 + + def test_value_non_aggregate_boolean_row_consumes_page_item_limit(self): + partial_result = {"Documents": [True]} + assert _count_page_items_from_partial_result(partial_result, "SELECT VALUE c.flag FROM c") == 1 + + +class TestAggregateMergeConsistency: + """Page-item counting and merge logic should classify aggregate fragments the same way.""" + + def test_value_count_boolean_fragments_are_not_treated_as_numeric_aggregates(self): + query = "SELECT VALUE COUNT(1) > 0 FROM c" + partial_result = {"Documents": [True]} + assert _count_page_items_from_partial_result(partial_result, query) == 1 + + merged = _base._merge_query_results({"Documents": [True]}, {"Documents": 
[True]}, query) + assert merged["Documents"] == [True, True] + + def test_value_count_numeric_fragments_are_treated_as_aggregates(self): + query = "SELECT VALUE COUNT(1) FROM c" + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, query) == 0 + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [10] + + def test_value_min_numeric_fragments_are_merged_with_min(self): + query = "SELECT VALUE MIN(c.score) FROM c" + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, query) == 0 + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [3] + + def test_value_max_numeric_fragments_are_merged_with_max(self): + query = "SELECT VALUE MAX(c.score) FROM c" + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, query) == 0 + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [7] + + def test_value_min_max_three_way_merge(self): + min_query = "SELECT VALUE MIN(c.score) FROM c" + merged_min = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, min_query) + merged_min = _base._merge_query_results(merged_min, {"Documents": [11]}, min_query) + assert merged_min["Documents"] == [3] + + max_query = "SELECT VALUE MAX(c.score) FROM c" + merged_max = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, max_query) + merged_max = _base._merge_query_results(merged_max, {"Documents": [11]}, max_query) + assert merged_max["Documents"] == [11] + + def test_value_boolean_non_aggregate_fragments_are_concatenated(self): + query = "SELECT VALUE c.flag FROM c" + partial_result = {"Documents": [True]} + assert _count_page_items_from_partial_result(partial_result, query) == 1 + + merged = _base._merge_query_results({"Documents": [True]}, 
{"Documents": [True]}, query) + assert merged["Documents"] == [True, True] + + def test_value_numeric_non_aggregate_fragments_are_concatenated(self): + """Regression: numeric VALUE rows must concatenate across partitions, not sum.""" + query = "SELECT VALUE c.score FROM c" + + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + assert _get_select_value_aggregate_function(query) is None + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [7, 3] + + def test_value_float_non_aggregate_fragments_are_concatenated(self): + """Same regression for floats; they must concatenate, not collapse.""" + query = "SELECT VALUE c.ratio FROM c" + + assert _count_page_items_from_partial_result({"Documents": [1.5]}, query) == 1 + merged = _base._merge_query_results( + {"Documents": [1.5]}, {"Documents": [2.25]}, query, + ) + assert merged["Documents"] == [1.5, 2.25] + + def test_value_numeric_non_aggregate_three_way_merge_is_concatenated(self): + """Three-partition fan-in must preserve order and avoid numeric collapse.""" + query = "SELECT VALUE c.score FROM c" + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + merged = _base._merge_query_results(merged, {"Documents": [11]}, query) + assert merged["Documents"] == [7, 3, 11] + + def test_value_merge_raises_if_aggregate_function_detection_is_missing(self, monkeypatch): + query = "SELECT VALUE COUNT(1) FROM c" + monkeypatch.setattr(_base, "_get_select_value_aggregate_function", lambda _: None) + + with pytest.raises(ValueError) as excinfo: + _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + + assert "VALUE aggregate classification" in str(excinfo.value) + + def test_value_avg_merge_raises_as_unsupported(self): + query = "SELECT VALUE AVG(c.value) FROM c" + + with pytest.raises(ValueError) as excinfo: + _base._merge_query_results({"Documents": [7.0]}, {"Documents": [3.0]}, query) + + 
assert "VALUE AVG aggregate merge" in str(excinfo.value) + + def test_raise_query_merge_value_error_rewrites_value_avg_message(self): + original = ValueError("VALUE AVG aggregate merge across partitions is not supported client-side.") + + with pytest.raises(ValueError) as excinfo: + _base._raise_query_merge_value_error(original) + + assert "SELECT VALUE AVG(...)" in str(excinfo.value) + assert "range-scoped pagination" in str(excinfo.value) + + def test_raise_query_merge_value_error_preserves_other_value_errors(self): + original = ValueError("Invariant violation: VALUE aggregate classification requires a recognized aggregate function.") + + with pytest.raises(ValueError) as excinfo: + _base._raise_query_merge_value_error(original) + + assert str(excinfo.value) == str(original) + + def test_value_aggregate_detection_allows_space_before_open_paren(self): + query = "SELECT VALUE COUNT (1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 0 + + def test_value_aggregate_detection_does_not_match_function_substrings(self): + query = "SELECT VALUE MYCOUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + def test_value_aggregate_detection_ignores_subquery_aggregate_tokens(self): + query = "SELECT VALUE c.name FROM c WHERE EXISTS(SELECT VALUE COUNT(1) FROM d)" + assert _get_select_value_aggregate_function(query) is None + + def test_numeric_value_row_with_subquery_aggregate_still_consumes_page_item(self): + query = "SELECT VALUE c.id FROM c WHERE EXISTS(SELECT VALUE COUNT(1) FROM d)" + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + def test_numeric_value_row_with_projection_subquery_aggregate_still_consumes_page_item(self): + query = "SELECT VALUE (SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert 
_get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + def test_numeric_value_row_with_array_projection_subquery_still_consumes_page_item(self): + query = "SELECT VALUE ARRAY(SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + +class TestSelectValueProjectionParser: + @pytest.mark.parametrize( + "normalized_query,expected_projection", + [ + ("SELECT VALUE COUNT(1) FROM C", "COUNT(1)"), + ("SELECT VALUE (SELECT VALUE COUNT(1) FROM D) FROM C", "(SELECT VALUE COUNT(1) FROM D)"), + ("SELECT VALUE C.FROMAGE FROM C", "C.FROMAGE"), + ("SELECT VALUE COUNT(1)", None), + ("SELECT VALUE (COUNT(1) FROM C", None), + ], + ) + def test_extract_outer_select_value_projection_edges(self, normalized_query, expected_projection): + assert _extract_outer_select_value_projection(normalized_query) == expected_projection + + def test_projection_level_subquery_is_not_classified_as_outer_aggregate(self): + query = "SELECT VALUE (SELECT VALUE COUNT(1) FROM d) FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_projection_level_in_subquery_is_not_classified_as_outer_aggregate(self): + query = "SELECT VALUE (SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_array_projection_subquery_is_not_classified_as_outer_aggregate(self): + query = "SELECT VALUE ARRAY(SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_nested_select_value_in_where_subquery_does_not_drive_outer_detection(self): + query = "SELECT c.count FROM c WHERE c.count IN (SELECT VALUE COUNT(1) FROM c)" + assert _get_select_value_aggregate_function(query) is None + + +class TestAggregateClassificationHeuristics: + def 
test_block_comment_prefix_does_not_drive_outer_select_value_detection(self): + query = "/* SELECT VALUE COUNT(1) */ SELECT VALUE c.x FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_value_aggregate_detected_with_comment_between_select_and_value(self): + query = "SELECT /* comment */ VALUE COUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + + def test_value_aggregate_detected_with_comment_between_value_and_function(self): + query = "SELECT VALUE /* comment */ COUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + + def test_comment_with_fake_from_does_not_truncate_projection(self): + query = "SELECT VALUE /* FROM d */ COUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + + def test_block_comment_inside_string_literal_is_not_stripped(self): + query = "SELECT VALUE '/* COUNT(1) */' FROM c" + stripped = _strip_sql_block_comments(query) + assert "/* COUNT(1) */" in stripped + + def test_value_projection_with_property_named_count_is_not_aggregate(self): + query = "SELECT VALUE c.COUNT FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [42.5]}, query) == 1 + + def test_classify_aggregate_partial_excludes_boolean_value_rows(self): + query = "SELECT VALUE COUNT(1) FROM c" + docs = [True] + assert _classify_aggregate_partial(docs, query) == _AggregatePartialClassification.NONE + + def test_classify_aggregate_partial_treats_non_aggregate_float_as_none(self): + query = "SELECT VALUE c.price FROM c" + docs = [42.5] + assert _classify_aggregate_partial(docs, query) == _AggregatePartialClassification.NONE + + +class TestEmptyPageStallCounter: + """No-progress guard only counts empty pages that still carry continuation.""" + + def test_increments_on_empty_page_with_continuation(self): + head_feedrange = _mk_range("10", "20") + assert _update_no_progress_page_count( + 3, + 
page_items_returned=0, + previous_feedrange=head_feedrange, + previous_backend_continuation="token", + head_feedrange=head_feedrange, + head_backend_continuation="token", + ) == 4 + + def test_increments_when_equal_bounds_are_different_objects(self): + # Guard against regressions where two equivalent ranges are reconstructed + # as distinct objects between loop iterations. + assert _update_no_progress_page_count( + 3, + page_items_returned=0, + previous_feedrange=_mk_range("10", "20"), + previous_backend_continuation="token", + head_feedrange=_mk_range("10", "20"), + head_backend_continuation="token", + ) == 4 + + def test_resets_when_items_are_returned(self): + head_feedrange = _mk_range("10", "20") + assert _update_no_progress_page_count( + 5, + page_items_returned=1, + previous_feedrange=head_feedrange, + previous_backend_continuation="token", + head_feedrange=head_feedrange, + head_backend_continuation="token", + ) == 0 + + def test_resets_when_continuation_is_none(self): + assert _update_no_progress_page_count( + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES - 1, + page_items_returned=0, + previous_feedrange=_mk_range("10", "20"), + previous_backend_continuation="token", + head_feedrange=_mk_range("20", "30"), + head_backend_continuation=None, + ) == 0 + + def test_resets_when_continuation_advances(self): + head_feedrange = _mk_range("10", "20") + assert _update_no_progress_page_count( + 8, + page_items_returned=0, + previous_feedrange=head_feedrange, + previous_backend_continuation="token-1", + head_feedrange=head_feedrange, + head_backend_continuation="token-2", + ) == 0 + + +class TestExplodeIterationGuard: + def test_increments_under_limit(self): + assert _increment_explode_iterations_or_raise(0) == 1 + + def test_raises_over_limit(self): + with pytest.raises(RuntimeError) as excinfo: + _increment_explode_iterations_or_raise(_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS) + assert "split re-resolution" in str(excinfo.value) + + +class TestFeedRangePaginationState: + """Unit 
tests for the shared pagination state machine. + + The state machine holds a single ordered queue of + ``(sub-range, backend continuation)`` pairs — modeled on Java's + ``Queue``. There is no "current vs. + remaining" split; the head is just ``queue[0]`` and the tail is + ``queue[1:]`` by virtue of being later in the same deque. + """ + + @staticmethod + def _pkr(pkr_id: str, mn: str, mx: str) -> dict: + return {"id": pkr_id, "minInclusive": mn, "maxExclusive": mx} + + @staticmethod + def _bounds(rng: routing_range.Range) -> tuple[str, str]: + return rng.min, rng.max + + @classmethod + def _queue_bounds(cls, state) -> list: + """All ``(min, max, bc)`` triples in the queue, in head-first order.""" + return [(r.min, r.max, bc) for r, bc in state.queue] + + def test_from_derived_feedranges_empty_initializes_done_state(self): + state = _FeedRangePaginationState.from_derived_feedranges([], page_size_hint=5) + assert list(state.queue) == [] + assert state.head_range is None + assert state.head_bc is None + assert state.page_size_hint == 5 + + def test_from_derived_feedranges_seeds_queue_with_no_continuations(self): + a = _mk_range("00", "40") + b = _mk_range("40", "80") + state = _FeedRangePaginationState.from_derived_feedranges([a, b], page_size_hint=7) + assert self._queue_bounds(state) == [("00", "40", None), ("40", "80", None)] + assert self._bounds(state.head_range) == ("00", "40") + assert state.head_bc is None + assert state.page_size_hint == 7 + + def test_from_single_feedrange_with_continuation_seeds_head_bc(self): + head = _mk_range("AA", "BB") + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + head, + "legacy-token-1", + page_size_hint=11, + ) + assert self._queue_bounds(state) == [("AA", "BB", "legacy-token-1")] + assert self._bounds(state.head_range) == ("AA", "BB") + assert state.head_bc == "legacy-token-1" + assert state.page_size_hint == 11 + + def test_from_single_feedrange_with_continuation_allows_null(self): + head = 
_mk_range("AA", "BB") + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + head, + None, + page_size_hint=None, + ) + assert self._queue_bounds(state) == [("AA", "BB", None)] + assert state.head_bc is None + + @pytest.mark.parametrize( + "queue,page_size_hint,expected", + [ + ([], None, False), + ([(_mk_range("00", "40"), None)], 0, True), + ([(_mk_range("00", "40"), None)], -1, True), + ([(_mk_range("00", "40"), None)], None, True), + ([(_mk_range("00", "40"), None)], 1, True), + ], + ) + def test_can_issue_request_boundaries(self, queue, page_size_hint, expected): + state = _FeedRangePaginationState( + queue=queue, + page_size_hint=page_size_hint, + ) + assert state.can_issue_request() is expected + + def test_from_inbound_parses_queue_and_continuations(self): + # The wire format is a single ordered ``c`` list. The state + # machine loads it as a single queue of ``(range, bc)`` pairs + # — ``c[0]`` becomes the head, with no privileged "current vs. + # remaining" split. + inbound = { + _FIELD_CONTINUATIONS: [ + {"min": "00", "max": "40", _FIELD_BACKEND_CONTINUATION: "token-1"}, + {"min": "40", "max": "80", _FIELD_BACKEND_CONTINUATION: None}, + ], + } + state = _FeedRangePaginationState.from_inbound(inbound, page_size_hint=9) + assert self._queue_bounds(state) == [ + ("00", "40", "token-1"), + ("40", "80", None), + ] + assert self._bounds(state.head_range) == ("00", "40") + assert state.head_bc == "token-1" + assert state.page_size_hint == 9 + + def test_from_inbound_preserves_per_entry_backend_continuations(self): + # Future non-sequential / parallel-loop case: a saved token + # where multiple ``c[i]`` entries each carry their own non-null + # backend continuation. All must round-trip into the queue + # untouched. 
+ inbound = { + _FIELD_CONTINUATIONS: [ + {"min": "00", "max": "40", _FIELD_BACKEND_CONTINUATION: "B-cont-5"}, + {"min": "40", "max": "80", _FIELD_BACKEND_CONTINUATION: "A-cont-5"}, + ], + } + state = _FeedRangePaginationState.from_inbound(inbound, page_size_hint=None) + assert self._queue_bounds(state) == [ + ("00", "40", "B-cont-5"), + ("40", "80", "A-cont-5"), + ] + + # When the loop drains the head slice, the next entry's saved + # backend continuation is naturally exposed as the new head_bc. + state.apply_post_result(items_returned=0, backend_continuation=None) + assert self._bounds(state.head_range) == ("40", "80") + assert state.head_bc == "A-cont-5" + assert self._queue_bounds(state) == [("40", "80", "A-cont-5")] + + def test_apply_post_result_with_continuation_updates_head_bc_in_place(self): + current = _mk_range("00", "40") + next_range = _mk_range("40", "80") + state = _FeedRangePaginationState( + queue=[(current, None), (next_range, None)], + page_size_hint=5, + ) + + state.apply_post_result(items_returned=2, backend_continuation="token-2") + + assert self._queue_bounds(state) == [ + ("00", "40", "token-2"), + ("40", "80", None), + ] + assert state.head_bc == "token-2" + assert state.page_size_hint == 5 + + def test_apply_post_result_with_none_pops_head(self): + current = _mk_range("00", "40") + next_range = _mk_range("40", "80") + state = _FeedRangePaginationState( + queue=[(current, "token-1"), (next_range, None)], + page_size_hint=6, + ) + + state.apply_post_result(items_returned=1, backend_continuation=None) + + assert self._queue_bounds(state) == [("40", "80", None)] + assert self._bounds(state.head_range) == ("40", "80") + assert state.head_bc is None + assert state.page_size_hint == 6 + + def test_apply_post_result_with_none_and_no_tail_drains_queue(self): + current = _mk_range("00", "40") + state = _FeedRangePaginationState( + queue=[(current, "token-1")], + page_size_hint=None, + ) + + state.apply_post_result(items_returned=1, 
backend_continuation=None) + + assert list(state.queue) == [] + assert state.head_range is None + assert state.head_bc is None + + def test_explode_on_multi_overlap_single_overlap_keeps_state(self): + current = _mk_range("00", "80") + tail = _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(current, "token-1"), (tail, None)], + page_size_hint=4, + ) + + did_explode = state.explode_on_multi_overlap([self._pkr("X", "00", "80")]) + + assert did_explode is False + assert self._queue_bounds(state) == [ + ("00", "80", "token-1"), + ("80", "C0", None), + ] + + def test_explode_on_multi_overlap_replaces_head_with_children(self): + current = _mk_range("00", "80") + tail = _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(current, "token-1"), (tail, None)], + page_size_hint=4, + ) + + did_explode = state.explode_on_multi_overlap( + [ + self._pkr("X1", "00", "40"), + self._pkr("X2", "40", "80"), + ] + ) + + assert did_explode is True + # Parent is dequeued and children are appended at the queue tail + # in EPK order (Java/.NET parity). Children inherit the parent's + # backend continuation so resume progress is preserved across split. 
+ assert self._queue_bounds(state) == [ + ("80", "C0", None), + ("00", "40", "token-1"), + ("40", "80", "token-1"), + ] + assert state.head_bc is None + + def test_explode_on_multi_overlap_with_no_parent_continuation_keeps_children_none(self): + current = _mk_range("00", "80") + tail = _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(current, None), (tail, None)], + page_size_hint=4, + ) + + did_explode = state.explode_on_multi_overlap( + [ + self._pkr("X1", "00", "40"), + self._pkr("X2", "40", "80"), + ] + ) + + assert did_explode is True + assert self._queue_bounds(state) == [ + ("80", "C0", None), + ("00", "40", None), + ("40", "80", None), + ] + + +class TestCheckpointRoundTripOnException: + """If the per-iteration POST raises mid-page, the call site stamps + the current pagination state into the outbound continuation header + before re-raising. That checkpoint must round-trip so the caller + can retry from exactly where the failed POST left off — never from + the start of the head sub-range, never skipping it. + + These tests drive the state machine + token codec end-to-end + without standing up a live client: emit the checkpoint that the + sync/async loops would emit on exception, then resume from it and + assert the queue is intact. + """ + + @staticmethod + def _bounds(rng: routing_range.Range) -> tuple[str, str]: + return rng.min, rng.max + + def test_checkpoint_emitted_mid_page_resumes_at_same_head(self): + # Simulate: queue is [A (in-flight, bc=cont-A), B, C]; the POST + # for A returned cont-A successfully on a previous iteration but + # the next POST (still on A) is about to raise. The call site + # writes the outbound continuation BEFORE re-raising — that + # token must put us back at (A, cont-A) on resume, with B and C + # still queued behind it untouched. 
+ a, b, c = _mk_range("00", "40"), _mk_range("40", "80"), _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(a, "cont-A"), (b, None), (c, None)], + page_size_hint=7, + ) + + headers: dict = {} + state.write_outbound_continuation(headers, _RID, _QUERY, _FEED_RANGE) + + # The header carries an opaque, base64-encoded v=1 envelope. + wire = headers[http_constants.HttpHeaders.Continuation] + assert isinstance(wire, str) and wire + + # Caller resumes from that exact wire token. + inbound = _decode_token(wire) + assert inbound is not None + _validate_token_identity(inbound, _RID, _QUERY, _FEED_RANGE) + resumed = _FeedRangePaginationState.from_inbound(inbound, page_size_hint=7) + + # Head is still A with its bc; tail (B, C) is intact and bc-free. + assert self._bounds(resumed.head_range) == ("00", "40") + assert resumed.head_bc == "cont-A" + assert [(r.min, r.max, bc) for r, bc in resumed.queue] == [ + ("00", "40", "cont-A"), + ("40", "80", None), + ("80", "C0", None), + ] + + def test_checkpoint_after_partial_drain_resumes_at_next_head(self): + # Simulate: A was fully drained (apply_post_result with bc=None + # popped it) and the POST for B is about to raise. The + # checkpoint must put us at (B, None) on resume with C still + # queued, and A must NOT reappear. + a, b, c = _mk_range("00", "40"), _mk_range("40", "80"), _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(a, "cont-A-final"), (b, None), (c, None)], + page_size_hint=5, + ) + state.apply_post_result(items_returned=3, backend_continuation=None) + # A is now drained; head should be B. 
+ assert self._bounds(state.head_range) == ("40", "80") + + headers: dict = {} + state.write_outbound_continuation(headers, _RID, _QUERY, _FEED_RANGE) + + resumed = _FeedRangePaginationState.from_inbound( + _decode_token(headers[http_constants.HttpHeaders.Continuation]), + page_size_hint=2, + ) + assert [(r.min, r.max, bc) for r, bc in resumed.queue] == [ + ("40", "80", None), + ("80", "C0", None), + ] + + def test_drained_state_clears_continuation_header(self): + # When the entire queue is drained, the call site must clear + # the outbound continuation header (not leave a stale one + # behind) so the caller's by_page loop terminates. + headers = {http_constants.HttpHeaders.Continuation: "stale-from-prior-page"} + state = _FeedRangePaginationState(queue=[], page_size_hint=None) + + state.write_outbound_continuation(headers, _RID, _QUERY, _FEED_RANGE) + + assert http_constants.HttpHeaders.Continuation not in headers + + +class TestWriteQueryOutboundContinuation: + """Outbound continuation format selection should match request scope policy.""" + + def test_full_pk_scope_always_emits_legacy(self): + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + _HEAD_FEEDRANGE, + _BACKEND_CONT, + page_size_hint=5, + ) + headers: dict = {} + + _write_query_outbound_continuation( + headers, + state, + _RID, + _QUERY, + _FEED_RANGE, + is_full_pk_structured_scope=True, + emit_legacy_for_single_partition=False, + ) + + assert headers[http_constants.HttpHeaders.Continuation] == _BACKEND_CONT + + def test_non_full_pk_single_partition_scope_emits_legacy(self): + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + _HEAD_FEEDRANGE, + _BACKEND_CONT, + page_size_hint=5, + ) + headers: dict = {} + + _write_query_outbound_continuation( + headers, + state, + _RID, + _QUERY, + _FEED_RANGE, + is_full_pk_structured_scope=False, + emit_legacy_for_single_partition=True, + ) + + assert headers[http_constants.HttpHeaders.Continuation] == _BACKEND_CONT 
+ + def test_non_full_pk_multi_partition_scope_emits_structured(self): + state = _FeedRangePaginationState( + [(_HEAD_FEEDRANGE, _BACKEND_CONT), (_REMAINING_FEEDRANGE, None)], + page_size_hint=5, + ) + headers: dict = {} + + _write_query_outbound_continuation( + headers, + state, + _RID, + _QUERY, + _FEED_RANGE, + is_full_pk_structured_scope=False, + emit_legacy_for_single_partition=False, + ) + + decoded = _decode_token(headers[http_constants.HttpHeaders.Continuation]) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + + +class TestLegacyBridgeDecision: + """Legacy inbound continuation is bridged only when scope is safely single-partition.""" + + @pytest.mark.parametrize( + "inbound_serialized_continuation,inbound_token_payload,is_full_pk_structured_scope," + "is_single_partition_scope,expected", + [ + (None, None, False, True, False), + ("", None, False, True, False), + ("legacy", {"v": 1}, False, True, False), + ("legacy", None, True, False, True), + ("legacy", None, False, True, True), + ("legacy", None, False, False, False), + ], + ) + def test_should_bridge_legacy_continuation_policy( + self, + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + is_single_partition_scope, + expected, + ): + assert _should_bridge_legacy_continuation( + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + is_single_partition_scope, + ) is expected + + diff --git a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py index 30e9792b5783..272763f07cc6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py +++ b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py @@ -11,6 +11,26 @@ from azure.cosmos import CosmosClient, PartitionKey +def _assert_full_text_policy_matches(returned, expected): + """Service may augment the returned policy with extra fields (e.g. ``defaultSpec``). 
+ Only assert the fields the caller actually specified are present and equal. + """ + for key, value in expected.items(): + assert returned.get(key) == value, ( + f"fullTextPolicy field {key!r} mismatch: expected {value!r}, got {returned.get(key)!r}" + ) + + +def _assert_message_contains(message, *needles): + """Case-insensitive substring check tolerant of minor service message wording + changes (e.g. ``"Full Text Policy"`` -> ``"full-text policy"``, + ``abstract`` -> ``'abstract'``).""" + haystack = message.lower().replace("'", "").replace("-", " ") + for needle in needles: + n = needle.lower().replace("'", "").replace("-", " ") + assert n in haystack, f"expected {needle!r} in error message, got: {message!r}" + + @pytest.mark.cosmosSearchQuery class TestFullTextPolicy(unittest.TestCase): client: CosmosClient = None @@ -54,8 +74,8 @@ def test_create_full_text_container(self): indexing_policy=indexing_policy ) properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) self.test_db.delete_container(created_container.id) # Create a container with a full text policy containing only default language @@ -68,7 +88,7 @@ def test_create_full_text_container(self): full_text_policy=full_text_policy_no_paths, ) properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_paths + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_paths) self.test_db.delete_container(created_container.id) # Create a container with a full text policy with a given path containing only default language @@ -86,7 +106,7 @@ def test_create_full_text_container(self): full_text_policy=full_text_policy_no_langs, ) properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_langs + 
_assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_langs) self.test_db.delete_container(created_container.id) def test_replace_full_text_container(self): @@ -120,7 +140,7 @@ def test_replace_full_text_container(self): indexing_policy=indexing_policy ) properties = replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['indexingPolicy'] != properties['indexingPolicy'] self.test_db.delete_container(created_container.id) @@ -133,7 +153,7 @@ def test_replace_full_text_container(self): indexing_policy=indexing_policy ) created_container_properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] # Replace the container with new policies @@ -146,7 +166,7 @@ def test_replace_full_text_container(self): indexing_policy=indexing_policy ) properties = replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['fullTextPolicy'] != properties['fullTextPolicy'] assert created_container_properties["indexingPolicy"] != properties["indexingPolicy"] @@ -172,7 +192,7 @@ def test_fail_create_full_text_policy(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an invalid Path: abstract" in e.http_error_message + 
_assert_message_contains(e.http_error_message, "full text policy", "invalid path", "abstract") # Pass a full text policy with an unsupported default language full_text_policy_wrong_default = { @@ -193,8 +213,7 @@ def test_fail_create_full_text_policy(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. Supported languages are:"\ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") # Pass a full text policy with an unsupported path language full_text_policy_wrong_default = { @@ -215,8 +234,7 @@ def test_fail_create_full_text_policy(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. 
Supported languages are:"\ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") def test_fail_create_full_text_indexing_policy(self): full_text_policy = { @@ -245,8 +263,9 @@ def test_fail_create_full_text_indexing_policy(self): # pytest.fail("Container creation should have failed for lack of embedding policy.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The path of the Full Text Index /path does not match the path specified in the Full Text Policy"\ - in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "/path", "does not match", "full text policy" + ) # Pass a full text indexing policy with a wrongly formatted path indexing_policy_wrong_path = { @@ -264,7 +283,9 @@ def test_fail_create_full_text_indexing_policy(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Full-text index specification at index (0) contains invalid path" in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "invalid path" + ) # Pass a full text indexing policy without a path field indexing_policy_no_path = { @@ -282,7 +303,9 @@ def test_fail_create_full_text_indexing_policy(self): pytest.fail("Container creation should have failed for missing path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Missing path in full-text index specification at index (0)" in e.http_error_message + _assert_message_contains( + e.http_error_message, "missing path", "full text index" + ) if __name__ == '__main__': diff --git a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py index f52092b102a9..72a081b181ac 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py @@ -13,6 +13,26 @@ from azure.cosmos.aio import CosmosClient +def _assert_full_text_policy_matches(returned, expected): + """Service may augment the returned policy with extra fields (e.g. ``defaultSpec``). + Only assert the fields the caller actually specified are present and equal. + """ + for key, value in expected.items(): + assert returned.get(key) == value, ( + f"fullTextPolicy field {key!r} mismatch: expected {value!r}, got {returned.get(key)!r}" + ) + + +def _assert_message_contains(message, *needles): + """Case-insensitive substring check tolerant of minor service message wording + changes (e.g. ``"Full Text Policy"`` -> ``"full-text policy"``, + ``abstract`` -> ``'abstract'``).""" + haystack = message.lower().replace("'", "").replace("-", " ") + for needle in needles: + n = needle.lower().replace("'", "").replace("-", " ") + assert n in haystack, f"expected {needle!r} in error message, got: {message!r}" + + @pytest.mark.cosmosSearchQuery class TestFullTextPolicyAsync(unittest.IsolatedAsyncioTestCase): host = test_config.TestConfig.host @@ -69,7 +89,7 @@ async def test_create_full_text_container_async(self): indexing_policy=indexing_policy ) properties = await created_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] await self.test_db.delete_container(created_container.id) @@ -83,7 +103,7 @@ async def test_create_full_text_container_async(self): full_text_policy=full_text_policy_no_paths, ) properties = await created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_paths + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_paths) await self.test_db.delete_container(created_container.id) # Create a container with a full text policy with 
a given path containing only default language @@ -101,7 +121,7 @@ async def test_create_full_text_container_async(self): full_text_policy=full_text_policy_no_langs, ) properties = await created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_langs + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_langs) async def test_replace_full_text_container_async(self): # Replace a container without a full text policy and full text indexing policy @@ -134,7 +154,7 @@ async def test_replace_full_text_container_async(self): indexing_policy=indexing_policy ) properties = await replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['indexingPolicy'] != properties['indexingPolicy'] await self.test_db.delete_container(created_container.id) @@ -147,7 +167,7 @@ async def test_replace_full_text_container_async(self): indexing_policy=indexing_policy ) created_container_properties = await created_container.read() - assert created_container_properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(created_container_properties["fullTextPolicy"], full_text_policy) assert created_container_properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] # Replace the container with new policies @@ -160,7 +180,7 @@ async def test_replace_full_text_container_async(self): indexing_policy=indexing_policy ) properties = await replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['fullTextPolicy'] != 
properties['fullTextPolicy'] assert created_container_properties["indexingPolicy"] != properties["indexingPolicy"] @@ -186,7 +206,7 @@ async def test_fail_create_full_text_policy_async(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an invalid Path: abstract" in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "invalid path", "abstract") # Pass a full text policy with an unsupported default language full_text_policy_wrong_default = { @@ -207,8 +227,7 @@ async def test_fail_create_full_text_policy_async(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. Supported languages are:" \ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") # Pass a full text policy with an unsupported path language full_text_policy_wrong_default = { @@ -229,8 +248,7 @@ async def test_fail_create_full_text_policy_async(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. 
Supported languages are:" \ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") async def test_fail_create_full_text_indexing_policy_async(self): full_text_policy = { @@ -259,8 +277,9 @@ async def test_fail_create_full_text_indexing_policy_async(self): # pytest.fail("Container creation should have failed for lack of embedding policy.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The path of the Full Text Index /path does not match the path specified in the Full Text Policy" \ - in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "/path", "does not match", "full text policy" + ) # Pass a full text indexing policy with a wrongly formatted path indexing_policy_wrong_path = { @@ -278,7 +297,9 @@ async def test_fail_create_full_text_indexing_policy_async(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Full-text index specification at index (0) contains invalid path" in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "invalid path" + ) # Pass a full text indexing policy without a path field indexing_policy_no_path = { @@ -296,7 +317,9 @@ async def test_fail_create_full_text_indexing_policy_async(self): pytest.fail("Container creation should have failed for missing path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Missing path in full-text index specification at index (0)" in e.http_error_message + _assert_message_contains( + e.http_error_message, "missing path", "full text index" + ) if __name__ == '__main__': diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index 090717519222..84dce71d567d 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -472,5 +472,88 @@ def test_write_fallback_to_global_after_regional_retries_exhausted(self): final_endpoint = lc.resolve_service_endpoint(write_request) assert final_endpoint == location1_endpoint + def test_unavailable_endpoints_not_dropped_from_routing_list(self): + """ + Unavailable endpoints should be appended to the end of the routing list, + not dropped entirely. + + Scenario: + - Customer has preferred_locations = ["East US", "West US 2"] + - East US is marked unavailable for writes + - Customer makes a request with excluded_locations = ["West US 2"] + - Expected: East US should still be available as fallback (unavailable but in the list) + """ + # Setup: Two preferred locations, multi-write enabled + preferred_locations = [location1_name, location2_name] + lc = refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + db_acc = create_database_account(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + # Verify initial state: Both locations are in write_regional_routing_contexts + write_contexts = lc.get_write_regional_routing_contexts() + assert len(write_contexts) == 2 + assert write_contexts[0].get_primary() == location1_endpoint + assert write_contexts[1].get_primary() == location2_endpoint + + # Mark location1 (East US) as unavailable for writes + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + # After marking unavailable, the routing list should still contain + # both endpoints - healthy ones first, unavailable ones at the end + write_contexts_after = lc.get_write_regional_routing_contexts() + assert len(write_contexts_after) == 2, \ + f"Expected 2 endpoints in routing list, got {len(write_contexts_after)}. " \ + "Unavailable endpoint was incorrectly dropped!" 
+ # location2 (healthy) should be first + assert write_contexts_after[0].get_primary() == location2_endpoint + # location1 (unavailable) should be at the end as fallback + assert write_contexts_after[1].get_primary() == location1_endpoint + + # Now simulate the customer request with excluded_locations = ["location2"] + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = [location2_name] + + # Resolve endpoint - should get location1 (unavailable) as the only remaining option + # NOT the global default endpoint! + resolved_endpoint = lc.resolve_service_endpoint(write_request) + + # Should fall back to location1 (unavailable regional endpoint) + # NOT the global endpoint + assert resolved_endpoint == location1_endpoint, \ + f"Expected {location1_endpoint} but got {resolved_endpoint}. " \ + f"Bug: Unavailable endpoint was dropped and SDK fell back to global endpoint!" + + def test_unavailable_endpoints_ordering_in_routing_list(self): + """ + Test that healthy endpoints come before unavailable endpoints in the routing list. + This ensures the SDK tries healthy regions first, but has unavailable ones as fallback. 
+ """ + # Setup: Three preferred locations + preferred_locations = [location1_name, location2_name, location3_name] + lc = refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + db_acc = create_database_account(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + # Mark location1 as unavailable + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + # Check ordering: location2, location3 (healthy) should come before location1 (unavailable) + write_contexts = lc.get_write_regional_routing_contexts() + assert len(write_contexts) == 3 + assert write_contexts[0].get_primary() == location2_endpoint # First healthy + assert write_contexts[1].get_primary() == location3_endpoint # Second healthy + assert write_contexts[2].get_primary() == location1_endpoint # Unavailable at end + + # Mark location2 as unavailable too + lc.mark_endpoint_unavailable_for_write(location2_endpoint, refresh_cache=True, context="test") + + # Check ordering: location3 (healthy) should come before location1, location2 (unavailable) + write_contexts = lc.get_write_regional_routing_contexts() + assert len(write_contexts) == 3 + assert write_contexts[0].get_primary() == location3_endpoint # Only healthy + # Unavailable ones at end, in original preferred order + assert write_contexts[1].get_primary() == location1_endpoint + assert write_contexts[2].get_primary() == location2_endpoint + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index 5564136002b7..194a6cbea2ec 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -8,13 +8,16 @@ import gc import time import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest 
from azure.cosmos import exceptions +from azure.cosmos._cosmos_client_connection import CosmosClientConnection from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext -from azure.cosmos.http_constants import StatusCodes, SubStatusCodes +from azure.cosmos._routing.feed_range_continuation import _FIELD_VERSION, _TOKEN_VERSION, _decode_token +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders + # tracemalloc is not available in PyPy, so we import conditionally try: @@ -34,10 +37,17 @@ def is_circuit_breaker_applicable(self, request): return False +class MockRoutingMapProvider: + """Mock routing map provider with a collection routing map cache.""" + def __init__(self): + self._collection_routing_map_by_item = {} + + class MockClient: """Mock Cosmos client for testing partition split retry logic.""" def __init__(self): self._global_endpoint_manager = MockGlobalEndpointManager() + self.last_response_headers = {} self.refresh_routing_map_provider_call_count = 0 def refresh_routing_map_provider(self): @@ -73,6 +83,98 @@ class TestPartitionSplitRetryUnit(unittest.TestCase): Sync unit tests for 410 partition split retry logic. 
""" + @staticmethod + def _create_minimal_connection() -> CosmosClientConnection: + client = CosmosClientConnection.__new__(CosmosClientConnection) + client.default_headers = {} + client.last_response_headers = {} + client._UpdateSessionIfRequired = lambda *args, **kwargs: None + return client + + def test_queryfeed_internal_capture_uses_options_dict(self): + """QueryFeed should honor _internal_response_headers_capture from options.""" + client = self._create_minimal_connection() + captured_headers = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token", "x-ms-request-charge": "1.0"} + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object( + client, + '_CosmosClientConnection__Get', + return_value=({"Documents": [{"id": "doc1"}]}, expected_headers), + ): + docs, response_headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": captured_headers}, + ) + + self.assertEqual(docs, [{"id": "doc1"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(captured_headers, expected_headers) + + def test_queryfeed_internal_capture_falls_back_to_kwargs(self): + """QueryFeed should still support kwargs-based internal capture for compatibility.""" + client = self._create_minimal_connection() + kwargs_capture = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-kwargs", "x-ms-request-charge": "1.0"} + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object( + client, + '_CosmosClientConnection__Get', + return_value=({"Documents": [{"id": "doc2"}]}, expected_headers), + ): + docs, 
response_headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc2"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(kwargs_capture, expected_headers) + + def test_queryfeed_internal_capture_both_present_populates_one(self): + """When both options- and kwargs-based capture dicts are present + (a configuration that does not occur in production — the two + upstream paths are mutually exclusive by design), QueryFeed must + populate exactly one of the two capture dicts with the response + headers. Precedence is intentionally unspecified. + """ + client = self._create_minimal_connection() + options_capture: dict = {} + kwargs_capture: dict = {} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-both", "x-ms-request-charge": "1.0"} + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object( + client, + '_CosmosClientConnection__Get', + return_value=({"Documents": [{"id": "doc3"}]}, expected_headers), + ): + docs, response_headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": options_capture}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc3"}]) + self.assertEqual(response_headers, expected_headers) + populated = [d for d in (options_capture, kwargs_capture) if d == expected_headers] + self.assertEqual( + len(populated), 1, + f"expected exactly one capture dict populated; got options={options_capture!r}, kwargs={kwargs_capture!r}", + ) + def test_execution_context_state_reset_on_partition_split(self): """ Test that execution context state is properly reset on 410 partition split 
retry. @@ -152,6 +254,221 @@ def mock_fetch_function(options): "refresh_routing_map_provider should be called once on 410" assert result == expected_docs, "Should return expected documents after retry" + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_uses_checkpoint_continuation_from_internal_capture(self, mock_execute): + """410 retry should resume from checkpoint continuation stamped by __QueryFeed.""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token" + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["checkpoint-token"] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_uses_queryfeed_captured_checkpoint_end_to_end(self, mock_execute): + """End-to-end: QueryFeed stamps capture dict, 410 occurs, retry resumes from checkpoint token.""" + mock_client = MockClient() + query_client = self._create_minimal_connection() + query_client._query_compatibility_mode = query_client._QueryCompatibilityMode.Default + + context = None + seen_continuations = [] + execute_call_count = [0] + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + continuation = req_headers.get(HttpHeaders.Continuation) + if continuation: + return 
({"Documents": [{"id": "resumed"}]}, {}) + return ({"Documents": [{"id": "checkpoint-page"}]}, {HttpHeaders.Continuation: "checkpoint-token"}) + + def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + execute_call_count[0] += 1 + if execute_call_count[0] == 1: + callback() + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def fetch_function(options): + seen_continuations.append(options.get("continuation")) + docs, headers = query_client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options=options, + ) + return docs, headers + + def mock_get_headers(*args, **kwargs): + options = args[7] if len(args) > 7 else kwargs.get("options", {}) + headers = {} + if options and options.get("continuation") is not None: + headers[HttpHeaders.Continuation] = options.get("continuation") + return headers + + context = _DefaultQueryExecutionContext(mock_client, {}, fetch_function) + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', side_effect=mock_get_headers): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object(query_client, '_CosmosClientConnection__Post', side_effect=post_side_effect): + result = context._fetch_items_helper_with_retries(fetch_function) + + assert execute_call_count[0] == 2 + assert seen_continuations == [None, "checkpoint-token"] + assert result == [{"id": "resumed"}] + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_ignores_stale_shared_client_headers(self, mock_execute): + """Retry resumes from request-local captured headers, not shared client headers.""" + mock_client = MockClient() + mock_client.last_response_headers = {HttpHeaders.Continuation: "stale-global-token"} + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = 
_DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "fresh-checkpoint" + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["fresh-checkpoint"] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_without_checkpoint_continuation_retries_from_none(self, mock_execute): + """If no checkpoint header is stamped, continuation should remain None on retry.""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture.clear() + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == [None] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_multiple_410_uses_latest_checkpoint_continuation(self, mock_execute): + """Across repeated 410 retries, execution should 
resume using the latest checkpoint token.""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-1" + raise create_410_partition_split_error() + if call_count[0] == 2: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-2" + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 3 + assert seen_continuations == ["checkpoint-token-2"] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_mid_pagination_split_retries_from_checkpoint_without_duplicates(self, mock_execute): + """Simulate page2 split and verify retry resumes from checkpoint token, not from page1.""" + mock_client = MockClient() + + docs_page_1 = [{"id": "1"}, {"id": "2"}, {"id": "3"}, {"id": "4"}, {"id": "5"}] + docs_page_2 = [{"id": "6"}, {"id": "7"}, {"id": "8"}, {"id": "9"}, {"id": "10"}] + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + return callback() + + mock_execute.side_effect = execute_side_effect + + fetch_calls = [] + + def mock_fetch_function(options): + continuation = options.get("continuation") + fetch_calls.append(continuation) + + if continuation is None: + return (docs_page_1, {HttpHeaders.Continuation: "token-after-page-1"}) + + if continuation == "token-after-page-1": + # Simulate __QueryFeed writing a checkpoint before re-raising split 
error. + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-after-split" + raise create_410_partition_split_error() + + if continuation == "checkpoint-after-split": + return (docs_page_2, {}) + + self.fail(f"Unexpected continuation seen by fetch: {continuation}") + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + first_result = context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(first_result, docs_page_1) + + second_result = context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(second_result, docs_page_2) + + # Validate the second page did not replay page-1 items and resumed from checkpoint. + self.assertEqual(fetch_calls, [None, "token-after-page-1", "checkpoint-after-split"]) + @patch('azure.cosmos._retry_utility.Execute') def test_pk_range_query_skips_410_retry_to_prevent_recursion(self, mock_execute): """ @@ -323,6 +640,238 @@ def mock_fetch_function(options): assert pk_refresh_calls == 0, \ f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" + def test_queryfeed_populates_capture_dict_from_options(self): + """`__QueryFeed` must read the capture dict from `options` and + populate it from the underlying response headers. + + This is the producer-side counterpart to the checkpoint tests + above: it does not inject into the capture dict, it asserts that + `__QueryFeed` itself does the population. Catches the + `options`-vs-`kwargs` extraction regression. + """ + from unittest.mock import patch as _patch + + # Build a CosmosClientConnection without running __init__; we + # only need the attributes that the no-query (read-feed) branch + # of __QueryFeed touches. 
+ conn = object.__new__(CosmosClientConnection) + conn.default_headers = {} + conn.last_response_headers = {} + conn.availability_strategy = None + conn.availability_strategy_executor = None + conn._global_endpoint_manager = MockGlobalEndpointManager() + conn._routing_map_provider = MockRoutingMapProvider() + conn.session = None + conn.connection_policy = MagicMock() + + capture_dict = {} + options = { + "_internal_response_headers_capture": capture_dict, + } + + canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-queryfeed"} + + request_obj_mock = MagicMock( + set_excluded_location_from_options=MagicMock(), + set_availability_strategy=MagicMock(), + headers={}, + ) + + # Patch the heavy collaborators inside __QueryFeed's no-query + # branch so we can drive it without a real pipeline. + with _patch( + "azure.cosmos._cosmos_client_connection.base.GetHeaders", + return_value={}, + ), \ + _patch( + "azure.cosmos._cosmos_client_connection.base.set_session_token_header" + ), \ + _patch( + "azure.cosmos._cosmos_client_connection.RequestObject", + return_value=request_obj_mock, + ) as request_obj_ctor, \ + _patch.object( + CosmosClientConnection, + "_CosmosClientConnection__Get", + return_value=( + {"Documents": [{"id": "1"}], "_count": 1}, + canned_headers, + ), + ) as mock_get: + + _ = request_obj_ctor # silence unused-warning + + # Invoke the name-mangled private method directly. + result, headers = conn._CosmosClientConnection__QueryFeed( + "/dbs/db/colls/c/docs", + "docs", + "rid1", + lambda r: r["Documents"], + lambda _c, b: b, + None, # query=None -> read-feed branch -> __Get + options, + None, # partition_key_range_id + ) + + assert mock_get.called, "expected __Get to be invoked on the no-query path" + + assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-queryfeed", ( + f"capture dict was not populated by __QueryFeed; got {capture_dict!r}. 
" + "This indicates __QueryFeed is not reading " + "'_internal_response_headers_capture' from options." + ) + + # the marker key must have been removed from options so it + # never leaks downstream into header construction or RequestObject. + assert "_internal_response_headers_capture" not in options, ( + "__QueryFeed should pop the capture marker out of options" + ) + + # Sanity check on the result tuple shape. + assert result == [{"id": "1"}] + assert headers is canned_headers + + def test_queryfeed_feed_range_legacy_inbound_single_partition_honors_and_emits_legacy(self): + """Legacy inbound continuation is honored when feed_range currently maps to one partition.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + + def overlap_side_effect(_rid, ranges, _opts): + _ = ranges + return single_overlap + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + seen_request_continuations = [] + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "legacy-next-token"} + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": 
"doc-1"}] + assert seen_request_continuations == ["legacy-inbound-token"] + assert headers.get(HttpHeaders.Continuation) == "legacy-next-token" + + def test_queryfeed_feed_range_legacy_inbound_multi_partition_restarts_and_emits_v1(self): + """Legacy inbound continuation is ignored when scope is multi-partition; outbound becomes v=1.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + child_left = {"id": "0", "minInclusive": "00", "maxExclusive": "7F"} + child_right = {"id": "1", "minInclusive": "7F", "maxExclusive": "FF"} + + def overlap_side_effect(_rid, ranges, _opts): + requested = ranges[0] + if requested.min == "00" and requested.max == "FF": + return [child_left, child_right] + if requested.min == "00" and requested.max == "7F": + return [child_left] + if requested.min == "7F" and requested.max == "FF": + return [child_right] + return [] + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + seen_request_continuations = [] + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "child-legacy-token"} + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert 
seen_request_continuations == [None] + outbound = headers.get(HttpHeaders.Continuation) + decoded = _decode_token(outbound) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + + def test_queryfeed_feed_range_routing_lookup_failure_stamps_checkpoint(self): + """A failure inside the mid-page routing-map lookup must stamp a resumable + checkpoint into ``last_response_headers[Continuation]`` before re-raising, + not just failures from the backend POST. + """ + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + routing_call_count = {"n": 0} + + def overlap_side_effect(_rid, _ranges, _opts): + routing_call_count["n"] += 1 + # First call (legacy bridge classification) succeeds; the mid-page + # iteration call fails so we exercise the widened try block. + if routing_call_count["n"] >= 2: + raise RuntimeError("routing-map-down") + return single_overlap + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post") as post_mock: + with pytest.raises(RuntimeError, match="routing-map-down"): + client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + post_mock.assert_not_called() + + # Checkpoint must be present so the caller can resume on retry. + # Single-partition scope => legacy-format checkpoint (the original inbound token). 
+ continuation = client.last_response_headers.get(HttpHeaders.Continuation) + assert continuation == "legacy-inbound-token" + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index a487fc6a85eb..8642d37b87de 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -8,14 +8,16 @@ import gc import time import unittest -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from azure.cosmos import exceptions -from azure.cosmos.http_constants import StatusCodes, SubStatusCodes +from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders from azure.cosmos.aio import CosmosClient # noqa: F401 - needed to resolve circular imports from azure.cosmos._execution_context.aio.base_execution_context import _DefaultQueryExecutionContext +from azure.cosmos._routing.feed_range_continuation import _FIELD_VERSION, _TOKEN_VERSION, _decode_token # tracemalloc is not available in PyPy, so we import conditionally try: @@ -35,10 +37,17 @@ def is_circuit_breaker_applicable(self, request): return False +class MockRoutingMapProvider: + """Mock routing map provider with a collection routing map cache.""" + def __init__(self): + self._collection_routing_map_by_item = {} + + class MockClient: """Mock Cosmos client for testing partition split retry logic.""" def __init__(self): self._global_endpoint_manager = MockGlobalEndpointManager() + self.last_response_headers = {} self.refresh_routing_map_provider_call_count = 0 def refresh_routing_map_provider(self): @@ -76,6 +85,107 @@ class TestPartitionSplitRetryUnitAsync(unittest.IsolatedAsyncioTestCase): Async unit tests for 410 partition split 
retry logic. """ + @staticmethod + def _create_minimal_connection() -> CosmosClientConnection: + client = CosmosClientConnection.__new__(CosmosClientConnection) + client.default_headers = {} + client.last_response_headers = {} + client._UpdateSessionIfRequired = lambda *args, **kwargs: None + return client + + async def test_queryfeed_internal_capture_uses_options_dict_async(self): + """Async QueryFeed should honor _internal_response_headers_capture from options.""" + client = self._create_minimal_connection() + captured_headers = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token", "x-ms-request-charge": "1.0"} + + async def _noop_set_session(*args, **kwargs): + return None + + async def _fake_get(*args, **kwargs): + return {"Documents": [{"id": "doc1"}]}, expected_headers + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', return_value={}): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(client, '_CosmosClientConnection__Get', side_effect=_fake_get): + docs, response_headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": captured_headers}, + ) + + self.assertEqual(docs, [{"id": "doc1"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(captured_headers, expected_headers) + self.assertEqual(client.last_response_headers, expected_headers) + + async def test_queryfeed_internal_capture_falls_back_to_kwargs_async(self): + """Async QueryFeed should still support kwargs-based internal capture for compatibility.""" + client = self._create_minimal_connection() + kwargs_capture = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-kwargs", "x-ms-request-charge": "1.0"} + + async def _noop_set_session(*args, **kwargs): + return None + + async def 
_fake_get(*args, **kwargs): + return {"Documents": [{"id": "doc2"}]}, expected_headers + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', return_value={}): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(client, '_CosmosClientConnection__Get', side_effect=_fake_get): + docs, response_headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc2"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(kwargs_capture, expected_headers) + self.assertEqual(client.last_response_headers, expected_headers) + + async def test_queryfeed_internal_capture_both_present_populates_one_async(self): + """When both options- and kwargs-based capture dicts are present + (a configuration that does not occur in production — the two + upstream paths are mutually exclusive by design), async QueryFeed + must populate exactly one of the two capture dicts with the + response headers. Precedence is intentionally unspecified. 
+ """ + client = self._create_minimal_connection() + options_capture: dict = {} + kwargs_capture: dict = {} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-both", "x-ms-request-charge": "1.0"} + + async def _noop_set_session(*args, **kwargs): + return None + + async def _fake_get(*args, **kwargs): + return {"Documents": [{"id": "doc3"}]}, expected_headers + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', return_value={}): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(client, '_CosmosClientConnection__Get', side_effect=_fake_get): + docs, response_headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": options_capture}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc3"}]) + self.assertEqual(response_headers, expected_headers) + populated = [d for d in (options_capture, kwargs_capture) if d == expected_headers] + self.assertEqual( + len(populated), 1, + f"expected exactly one capture dict populated; got options={options_capture!r}, kwargs={kwargs_capture!r}", + ) + self.assertEqual(client.last_response_headers, expected_headers) + async def test_execution_context_state_reset_on_partition_split_async(self): """ Test that execution context state is properly reset on 410 partition split retry (async). 
@@ -151,6 +261,223 @@ async def mock_fetch_function(options): "refresh_routing_map_provider should be called once on 410" assert result == expected_docs, "Should return expected documents after retry" + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_uses_checkpoint_continuation_from_internal_capture_async(self, mock_execute): + """410 retry should resume from checkpoint continuation stamped by __QueryFeed (async).""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token" + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["checkpoint-token"] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_uses_queryfeed_captured_checkpoint_end_to_end_async(self, mock_execute): + """End-to-end async: QueryFeed stamps capture dict, 410 occurs, retry resumes from checkpoint token.""" + mock_client = MockClient() + query_client = self._create_minimal_connection() + query_client._query_compatibility_mode = query_client._QueryCompatibilityMode.Default + + seen_continuations = [] + execute_call_count = [0] + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + continuation 
= req_headers.get(HttpHeaders.Continuation) + if continuation: + return {"Documents": [{"id": "resumed"}]}, {} + return {"Documents": [{"id": "checkpoint-page"}]}, {HttpHeaders.Continuation: "checkpoint-token"} + + async def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + execute_call_count[0] += 1 + if execute_call_count[0] == 1: + await callback() + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def _noop_set_session(*args, **kwargs): + return None + + async def fetch_function(options): + seen_continuations.append(options.get("continuation")) + docs, headers = await query_client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options=options, + ) + return docs, headers + + def mock_get_headers(*args, **kwargs): + options = args[7] if len(args) > 7 else kwargs.get("options", {}) + headers = {} + if options and options.get("continuation") is not None: + headers[HttpHeaders.Continuation] = options.get("continuation") + return headers + + context = _DefaultQueryExecutionContext(mock_client, {}, fetch_function) + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', side_effect=mock_get_headers): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(query_client, '_CosmosClientConnection__Post', side_effect=post_side_effect): + result = await context._fetch_items_helper_with_retries(fetch_function) + + assert execute_call_count[0] == 2 + assert seen_continuations == [None, "checkpoint-token"] + assert result == [{"id": "resumed"}] + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_ignores_stale_shared_client_headers_async(self, mock_execute): + """Retry resumes from request-local captured headers, not shared client headers.""" + mock_client 
= MockClient() + mock_client.last_response_headers = {HttpHeaders.Continuation: "stale-global-token"} + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + async def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "fresh-checkpoint" + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["fresh-checkpoint"] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_without_checkpoint_continuation_retries_from_none_async(self, mock_execute): + """If no checkpoint header is stamped, continuation should remain None on retry (async).""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture.clear() + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert 
call_count[0] == 2 + assert seen_continuations == [None] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_multiple_410_uses_latest_checkpoint_continuation_async(self, mock_execute): + """Across repeated 410 retries, execution should resume using the latest checkpoint token (async).""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-1" + raise create_410_partition_split_error() + if call_count[0] == 2: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-2" + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 3 + assert seen_continuations == ["checkpoint-token-2"] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_mid_pagination_split_retries_from_checkpoint_without_duplicates_async(self, mock_execute): + """Simulate page2 split and verify async retry resumes from checkpoint token, not from page1.""" + mock_client = MockClient() + + docs_page_1 = [{"id": "1"}, {"id": "2"}, {"id": "3"}, {"id": "4"}, {"id": "5"}] + docs_page_2 = [{"id": "6"}, {"id": "7"}, {"id": "8"}, {"id": "9"}, {"id": "10"}] + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + return await callback() + + 
mock_execute.side_effect = execute_side_effect + + fetch_calls = [] + + async def mock_fetch_function(options): + continuation = options.get("continuation") + fetch_calls.append(continuation) + + if continuation is None: + return (docs_page_1, {HttpHeaders.Continuation: "token-after-page-1"}) + + if continuation == "token-after-page-1": + # Simulate __QueryFeed writing a checkpoint before re-raising split error. + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-after-split" + raise create_410_partition_split_error() + + if continuation == "checkpoint-after-split": + return (docs_page_2, {}) + + self.fail(f"Unexpected continuation seen by fetch: {continuation}") + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + first_result = await context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(first_result, docs_page_1) + + second_result = await context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(second_result, docs_page_2) + + # Validate the second page did not replay page-1 items and resumed from checkpoint. + self.assertEqual(fetch_calls, [None, "token-after-page-1", "checkpoint-after-split"]) + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') async def test_pk_range_query_skips_410_retry_to_prevent_recursion_async(self, mock_execute): """ @@ -311,6 +638,200 @@ async def mock_fetch_function(options): assert pk_refresh_calls == 0, \ f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" + async def test_queryfeed_populates_capture_dict_from_options_async(self): + """Async `__QueryFeed` must read the capture dict from `options` + and populate it from the underlying response headers, with no + test-side injection. Catches the `options`-vs-`kwargs` + extraction regression on the async path. 
+ """ + from unittest.mock import patch as _patch + + # Build a CosmosClientConnection without running __init__; we + # only need the attributes that the no-query (read-feed) branch + # of async __QueryFeed touches. + conn = object.__new__(CosmosClientConnection) + conn.default_headers = {} + conn.last_response_headers = {} + conn.availability_strategy = None + conn.availability_strategy_max_concurrency = None + conn._global_endpoint_manager = MockGlobalEndpointManager() + conn._routing_map_provider = MagicMock(_collection_routing_map_by_item={}) + conn.session = None + conn.connection_policy = MagicMock() + conn._UpdateSessionIfRequired = MagicMock() + + capture_dict = {} + options = { + "_internal_response_headers_capture": capture_dict, + } + + canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-queryfeed-async"} + + request_obj_mock = MagicMock( + set_excluded_location_from_options=MagicMock(), + set_availability_strategy=MagicMock(), + headers={}, + operation_type="ReadFeed", + ) + + # Patch the heavy collaborators inside async __QueryFeed's + # no-query branch so we can drive it without a real pipeline. + with _patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", + return_value={}, + ), \ + _patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + new=AsyncMock(), + ), \ + _patch( + "azure.cosmos.aio._cosmos_client_connection_async._request_object.RequestObject", + return_value=request_obj_mock, + ), \ + _patch.object( + CosmosClientConnection, + "_CosmosClientConnection__Get", + new=AsyncMock(return_value=( + {"Documents": [{"id": "1"}], "_count": 1}, + canned_headers, + )), + ) as mock_get: + + # Invoke the name-mangled private async method directly. 
+ result = await conn._CosmosClientConnection__QueryFeed( + "/dbs/db/colls/c/docs", + "docs", + "rid1", + lambda r: r["Documents"], + lambda _c, b: b, + None, # query=None -> read-feed branch -> __Get + options, + None, # partition_key_range_id + ) + + assert mock_get.await_count >= 1, "expected __Get to be awaited on the no-query path" + + assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-queryfeed-async", ( + f"capture dict was not populated by async __QueryFeed; got {capture_dict!r}. " + "This indicates async __QueryFeed is not reading " + "'_internal_response_headers_capture' from options." + ) + + # And the marker key must have been removed from options so it + # never leaks downstream into header construction or RequestObject. + assert "_internal_response_headers_capture" not in options, ( + "async __QueryFeed should pop the capture marker out of options" + ) + + # Sanity check: async no-query branch returns just the body list. + assert result == [{"id": "1"}] + + async def test_queryfeed_feed_range_legacy_inbound_single_partition_honors_and_emits_legacy_async(self): + """Async: legacy inbound continuation is honored when feed_range maps to one partition.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + + async def overlap_side_effect(_rid, ranges, _opts): + _ = ranges + return single_overlap + + client._routing_map_provider.get_overlapping_ranges = AsyncMock(side_effect=overlap_side_effect) + + seen_request_continuations = [] + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "legacy-next-token"} + + async def _noop_set_session(*args, **kwargs): + return 
None + + with patch("azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}): + with patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + side_effect=_noop_set_session, + ): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert seen_request_continuations == ["legacy-inbound-token"] + assert headers.get(HttpHeaders.Continuation) == "legacy-next-token" + + async def test_queryfeed_feed_range_legacy_inbound_multi_partition_restarts_and_emits_v1_async(self): + """Async: legacy inbound continuation is ignored when scope is multi-partition; outbound becomes v=1.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + child_left = {"id": "0", "minInclusive": "00", "maxExclusive": "7F"} + child_right = {"id": "1", "minInclusive": "7F", "maxExclusive": "FF"} + + async def overlap_side_effect(_rid, ranges, _opts): + requested = ranges[0] + if requested.min == "00" and requested.max == "FF": + return [child_left, child_right] + if requested.min == "00" and requested.max == "7F": + return [child_left] + if requested.min == "7F" and requested.max == "FF": + return [child_right] + return [] + + client._routing_map_provider.get_overlapping_ranges = AsyncMock(side_effect=overlap_side_effect) + + seen_request_continuations = [] + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": 
[{"id": "doc-1"}]}, {HttpHeaders.Continuation: "child-legacy-token"} + + async def _noop_set_session(*args, **kwargs): + return None + + with patch("azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}): + with patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + side_effect=_noop_set_session, + ): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert seen_request_continuations == [None] + outbound = headers.get(HttpHeaders.Continuation) + decoded = _decode_token(outbound) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 4970723e7757..84ac3116efab 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -12,6 +12,7 @@ import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import http_constants, DatabaseProxy, _endpoint_discovery_retry_policy +from azure.cosmos._routing.feed_range_continuation import _decode_token from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos.documents import _DistinctType @@ -38,8 +39,7 @@ def setUpClass(cls): use_multiple_write_locations = True cls.client = cosmos_client.CosmosClient(cls.host, cls.credential, multiple_write_locations=use_multiple_write_locations) 
cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) - if cls.host == "https://localhost:8081/": - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" + def test_first_and_last_slashes_trimmed_for_query_string(self): created_collection = self.created_db.create_container( @@ -105,12 +105,20 @@ def test_populate_index_metrics(self): self.assertTrue(INDEX_HEADER_NAME in created_collection.client_connection.last_response_headers) index_metrics = created_collection.client_connection.last_response_headers[INDEX_HEADER_NAME] self.assertIsNotNone(index_metrics) - expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 'IndexSpec': '/pk/?', - 'FilterPreciseSet': True, 'IndexPreciseSet': True, - 'IndexImpactScore': 'High'}], - 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], - 'PotentialCompositeIndexes': []} - self.assertDictEqual(expected_index_metrics, index_metrics) + self.assertIn('UtilizedSingleIndexes', index_metrics) + self.assertIn('PotentialSingleIndexes', index_metrics) + self.assertIn('UtilizedCompositeIndexes', index_metrics) + self.assertIn('PotentialCompositeIndexes', index_metrics) + + # Backend index diagnostics can vary by region/build; validate a stable shape and key signal. 
+ candidate_indexes = list(index_metrics.get('UtilizedSingleIndexes', [])) + candidate_indexes.extend(index_metrics.get('PotentialSingleIndexes', [])) + self.assertTrue(any( + idx.get('FilterExpression') == '' + and idx.get('IndexImpactScore') == 'High' + and idx.get('IndexSpec') in ('/pk/?', '/_epk/?') + for idx in candidate_indexes + )) self.created_db.delete_container(created_collection.id) # TODO: Need to validate the query request count logic @@ -469,6 +477,46 @@ def test_cross_partition_query_with_continuation_token(self): self.assertEqual(second_page['id'], second_page_fetched_with_continuation_token['id']) + def test_full_pk_continuation_emits_legacy_by_default(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + + self.assertIsNotNone(token) + self.assertIsNone(_decode_token(token)) + + def test_full_pk_legacy_replay_resumes_same_page(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + second_page = list(pager.next())[0] + + self.assertIsNotNone(token) + self.assertIsNone(_decode_token(token)) + + replay_pager = query_iterable.by_page(token) + replay_second_page = list(replay_pager.next())[0] + 
self.assertEqual(second_page['id'], replay_second_page['id']) + + def test_cross_partition_query_with_none_partition_key(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) document_definition = {'pk': 'pk1', 'id': str(uuid.uuid4())} diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index 40207abddf14..4ec4af5c1406 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -12,6 +12,7 @@ import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import http_constants, _endpoint_discovery_retry_policy +from azure.cosmos._routing.feed_range_continuation import _decode_token from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos._retry_options import RetryOptions from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy @@ -49,12 +50,11 @@ async def asyncSetUp(self): self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=self.use_multiple_write_locations) await self.client.__aenter__() self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) - if self.host == "https://localhost:8081/": - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" async def asyncTearDown(self): await self.client.close() + async def test_first_and_last_slashes_trimmed_for_query_string_async(self): created_collection = await self.created_db.create_container( str(uuid.uuid4()), PartitionKey(path="/pk")) @@ -126,12 +126,20 @@ async def test_populate_index_metrics_async(self): assert index_header_name in created_collection.client_connection.last_response_headers index_metrics = created_collection.client_connection.last_response_headers[index_header_name] assert index_metrics != {} - expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 
'IndexSpec': '/pk/?', - 'FilterPreciseSet': True, 'IndexPreciseSet': True, - 'IndexImpactScore': 'High'}], - 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], - 'PotentialCompositeIndexes': []} - assert expected_index_metrics == index_metrics + assert 'UtilizedSingleIndexes' in index_metrics + assert 'PotentialSingleIndexes' in index_metrics + assert 'UtilizedCompositeIndexes' in index_metrics + assert 'PotentialCompositeIndexes' in index_metrics + + # Backend index diagnostics can vary by region/build; validate stable signal instead of exact payload. + candidate_indexes = list(index_metrics.get('UtilizedSingleIndexes', [])) + candidate_indexes.extend(index_metrics.get('PotentialSingleIndexes', [])) + assert any( + idx.get('FilterExpression') == '' + and idx.get('IndexImpactScore') == 'High' + and idx.get('IndexSpec') in ('/pk/?', '/_epk/?') + for idx in candidate_indexes + ) await self.created_db.delete_container(created_collection.id) @@ -465,6 +473,49 @@ async def test_cross_partition_query_with_continuation_token_async(self): assert second_page['id'] == second_page_fetched_with_continuation_token['id'] + async def test_full_pk_continuation_emits_legacy_by_default_async(self): + """Full partition-key queries return legacy continuation tokens by default.""" + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + + assert token is not None + assert _decode_token(token) is None + + + async def test_full_pk_legacy_replay_resumes_same_page_async(self): + created_collection = 
self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + second_page = [item async for item in await pager.__anext__()][0] + + assert token is not None + assert _decode_token(token) is None + + replay_pager = query_iterable.by_page(token) + replay_second_page = [item async for item in await replay_pager.__anext__()][0] + assert second_page['id'] == replay_second_page['id'] + + + async def test_cross_partition_query_with_none_partition_key_async(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) document_definition = {'pk': 'pk1', 'id': str(uuid.uuid4())} diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py index b1b4f050ab63..6daaf2304ad6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py @@ -44,8 +44,6 @@ def setUpClass(cls): use_multiple_write_locations = True cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) - if cls.host == "https://localhost:8081/": - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" def setUp(self): self.created_container = self.created_db.create_container( @@ -220,12 +218,20 @@ def test_populate_index_metrics(self): self.assertTrue(INDEX_HEADER_NAME in self.created_container.client_connection.last_response_headers) index_metrics = 
self.created_container.client_connection.last_response_headers[INDEX_HEADER_NAME] self.assertIsNotNone(index_metrics) - expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 'IndexSpec': '/pk/?', - 'FilterPreciseSet': True, 'IndexPreciseSet': True, - 'IndexImpactScore': 'High'}], - 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], - 'PotentialCompositeIndexes': []} - self.assertDictEqual(expected_index_metrics, index_metrics) + self.assertIn('UtilizedSingleIndexes', index_metrics) + self.assertIn('PotentialSingleIndexes', index_metrics) + self.assertIn('UtilizedCompositeIndexes', index_metrics) + self.assertIn('PotentialCompositeIndexes', index_metrics) + + # Backend index diagnostics can vary by region/build; validate stable signal instead of exact payload. + candidate_indexes = list(index_metrics.get('UtilizedSingleIndexes', [])) + candidate_indexes.extend(index_metrics.get('PotentialSingleIndexes', [])) + self.assertTrue(any( + idx.get('FilterExpression') == '' + and idx.get('IndexImpactScore') == 'High' + and idx.get('IndexSpec') in ('/pk/?', '/_epk/?') + for idx in candidate_indexes + )) def test_get_query_plan_through_gateway(self): self._validate_query_plan(query="Select top 10 value count(c.id) from c", diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py new file mode 100644 index 000000000000..8af9f0517310 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py @@ -0,0 +1,1005 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""End-to-end tests for ``query_items(feed_range=...)`` against a feed_range +that overlaps more than one physical partition. 
+ +These tests pin two invariants of the multi-overlap pagination contract: + +* every page returns at most ``max_item_count`` items, and +* no item id is returned on more than one page (no duplicates across the + fan-out / resume boundary). + +Three scenarios are covered: + +* ``test_two_partition_feed_range`` — feed_range overlaps two adjacent + physical partitions. +* ``test_three_way_overlap`` — feed_range overlaps three adjacent physical + partitions; wider fan-out. +* ``test_post_split_resume`` — emit a continuation under one physical + layout, force a real partition split, then resume with the same + continuation against the new layout. Slow (``cosmosSplit``). + + +Async parity lives in ``test_query_feed_range_multipartition_async.py``. +""" + +import time +import unittest +import uuid +from typing import Iterable, List, Optional, Tuple + +import pytest + +import test_config +from azure.cosmos import _base +from azure.cosmos import CosmosClient, documents, http_constants +from azure.cosmos._routing.feed_range_continuation import _decode_token +from azure.cosmos.partition_key import PartitionKey + +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID + +# Dedicated container for these tests. Throughput is chosen so the +# container has multiple physical partitions out of the box (no split needed +# for the two/three-overlap tests); the post-split test then forces an +# additional split on top. +REPRO_CONTAINER_ID = "FeedRangeMultiPartition-" + str(uuid.uuid4()) +REPRO_PARTITION_KEY = "pk" +REPRO_THROUGHPUT = CONFIG.THROUGHPUT_FOR_5_PARTITIONS # 30000 → ~5 partitions +REPRO_DOC_COUNT = 200 # spread across partitions; ensures every partition has data + +# Per-page cap applied to every multi-overlap query in this module. 
+# Small enough to drive several pages per partition under the seeded data +# count, so any per-page over-fetch or duplicate-on-resume shows up across +# the page sequence rather than only on the last page. +PAGE_SIZE = 5 + +# Per-overlap data threshold below which we skip a configuration as not a +# meaningful repro. Need enough docs in each partition to drive ≥ 3 pages +# under PAGE_SIZE = 5. +MIN_DOCS_PER_PARTITION = 15 + + +def _client() -> CosmosClient: + return CosmosClient(HOST, KEY) + + +def _get_container(): + db = _client().get_database_client(DATABASE_ID) + return db.get_container_client(REPRO_CONTAINER_ID) + + +def _sorted_partition_ranges(container) -> List[Tuple[str, str]]: + """Return current physical partitions' EPK ranges as (min, max) tuples, + sorted by ``min``. Reads the routing map via ``read_feed_ranges()`` (the + public surface that returns one dict per current physical partition). + """ + feed_ranges = list(container.read_feed_ranges()) + pairs: List[Tuple[str, str]] = [] + for fr in feed_ranges: + r = fr["Range"] + pairs.append((r["min"], r["max"])) + pairs.sort(key=lambda p: p[0]) + return pairs + + +def _count_in_range(container, range_min: str, range_max: str) -> int: + fr = test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + items = list(container.query_items( + query="SELECT VALUE COUNT(1) FROM c", feed_range=fr)) + return items[0] if items else 0 + + +def _crossing_feed_range(range_min: str, range_max: str): + """Synthesize a feed_range whose ``[min, max)`` interval spans the union + of one or more current physical partitions — the shape a feed_range + takes after the underlying partition has been split.""" + return test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + + +def _ids_via_per_partition_scan(container, 
partition_ranges: Iterable[Tuple[str, str]]): + """Ground-truth set of document ids inside the union of the given physical + partition ranges. Each partition is queried independently (each call is a + single-overlap query that does NOT exercise the multi-overlap fan-out + branch), so this is the correct baseline to compare a crossing-feed_range + query against.""" + ground_truth = set() + for (mn, mx) in partition_ranges: + fr = _crossing_feed_range(mn, mx) + for item in container.query_items(query="SELECT c.id FROM c", feed_range=fr): + ground_truth.add(item["id"]) + return ground_truth + + +def _drain_pages(pager) -> Tuple[List[List[dict]], List[str]]: + """Iterate ``pager`` to exhaustion. Return the per-page item lists (so the + caller can assert on per-page sizes) and the ordered list of all ids + encountered (so the caller can assert on duplicates).""" + pages: List[List[dict]] = [] + all_ids: List[str] = [] + for page in pager: + items = list(page) + pages.append(items) + all_ids.extend(item["id"] for item in items) + return pages, all_ids + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(): + """Create a dedicated container for these tests, populate it with enough + documents that every physical partition has data on both sides of every + internal boundary, and tear it down afterwards.""" + client = _client() + db = client.get_database_client(DATABASE_ID) + container = db.create_container_if_not_exists( + id=REPRO_CONTAINER_ID, + partition_key=PartitionKey(path="/" + REPRO_PARTITION_KEY, kind="Hash"), + offer_throughput=REPRO_THROUGHPUT) + # Insert REPRO_DOC_COUNT documents with distinct partition-key values. + # SHA-based PK hashing distributes these roughly uniformly across the + # container's physical partitions, so each partition ends up with a few + # dozen documents — enough to drive multiple pages at PAGE_SIZE=5. 
+ for i in range(REPRO_DOC_COUNT): + container.upsert_item({ + REPRO_PARTITION_KEY: f"pk-{i:04d}", + "id": f"doc-{i:04d}", + "value": i, + }) + yield + try: + db.delete_container(REPRO_CONTAINER_ID) + except Exception: # pylint: disable=broad-except + pass + + +@pytest.mark.cosmosQuery +class TestFeedRangeMultiPartition: + """Sync end-to-end tests for feed_range queries that overlap multiple + physical partitions.""" + + # ------------------------------------------------------------------ # + # Single-partition control (regression guard for the no-fan-out path) + # ------------------------------------------------------------------ # + def test_single_partition_feed_range(self): + """``feed_range`` strictly inside one physical partition's EPK + range, ``max_item_count=PAGE_SIZE``: every page must contain + exactly ``PAGE_SIZE`` items (except possibly the last one), no + duplicates across pages, and the last page's continuation must be + ``None``. + + This is the path the vast majority of feedRanges follow. It does + NOT exercise the multi-overlap fan-out branch; a regression here + means the single-overlap path itself is broken. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if not partitions: + pytest.skip("Container has no physical partitions") + + # Pick the first partition that holds enough docs to drive + # multiple PAGE_SIZE pages. 
+ chosen_pp = None + for (mn, mx) in partitions: + if _count_in_range(container, mn, mx) >= MIN_DOCS_PER_PARTITION: + chosen_pp = (mn, mx) + break + if chosen_pp is None: + pytest.skip("No single partition populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + single = _crossing_feed_range(chosen_pp[0], chosen_pp[1]) + ground_truth = _ids_via_per_partition_scan(container, [chosen_pp]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=single, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + # Every page except the last one is exactly PAGE_SIZE; the last + # page is at most PAGE_SIZE. + for idx, page in enumerate(pages): + if idx < len(pages) - 1: + assert len(page) == PAGE_SIZE, ( + f"page {idx} returned {len(page)} items, expected " + f"exactly {PAGE_SIZE} (only the last page is allowed " + "to be short on the single-overlap path)") + else: + assert len(page) <= PAGE_SIZE + + # No duplicates and full coverage of the partition. + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"single-partition path returned duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"single-partition coverage mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + # After the last page, the continuation token must be None + # (composite drained -> caller's loop terminates correctly). 
+ assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages; got " + f"{pager.continuation_token!r}") + + # ------------------------------------------------------------------ # + # Partition-key caller shapes (full key and prefix key) + # ------------------------------------------------------------------ # + def test_full_partition_key_query_pagination_resume(self): + """Full hierarchical partition-key query resumes correctly by continuation. + + This uses a dedicated MultiHash container and a full key value so the + query stays scoped to one logical partition while still exercising + pagination + resume on the partition_key path. + """ + db = _client().get_database_client(DATABASE_ID) + container_id = "FeedRangeMultiPartitionFullPK-" + str(uuid.uuid4()) + created_container = db.create_container_if_not_exists( + id=container_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash), + offer_throughput=400, + ) + try: + full_key = ['CA', 'Oxnard', '93033'] + for i in range(25): + created_container.upsert_item({ + 'id': f'full-pk-doc-{i:03d}', + 'state': full_key[0], + 'city': full_key[1], + 'zipcode': full_key[2], + 'value': i, + }) + for i in range(5): + created_container.upsert_item({ + 'id': f'other-doc-{i:03d}', + 'state': 'WA', + 'city': 'Seattle', + 'zipcode': f'98{i:03d}', + 'value': i, + }) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_container.query_items( + query=query, + partition_key=full_key, + max_item_count=7, + ) + + pager = query_iterable.by_page() + first_page = list(next(pager)) + assert first_page + token = pager.continuation_token + assert token + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resumed_remaining_ids = [] + for page in query_iterable.by_page(token): + resumed_remaining_ids.extend(item['id'] for item in page) + + assert 
expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] for item in created_container.query_items(query=query, partition_key=full_key) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + finally: + db.delete_container(created_container.id) + + def test_prefix_partition_key_query_pagination_resume(self): + """Prefix hierarchical partition-key query resumes correctly by continuation. + + The caller provides only the first level (``['CA']``). The query spans + multiple descendants under that prefix and must preserve continuation + correctness on resume. + """ + db = _client().get_database_client(DATABASE_ID) + container_id = "FeedRangeMultiPartitionPrefixPK-" + str(uuid.uuid4()) + created_container = db.create_container_if_not_exists( + id=container_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash), + offer_throughput=400, + ) + try: + for i in range(30): + created_container.upsert_item({ + 'id': f'ca-doc-{i:03d}', + 'state': 'CA', + 'city': f'city-{i % 5}', + 'zipcode': f'zip-{i:03d}', + 'value': i, + }) + for i in range(6): + created_container.upsert_item({ + 'id': f'wa-doc-{i:03d}', + 'state': 'WA', + 'city': f'city-{i % 2}', + 'zipcode': f'zip-{i:03d}', + 'value': i, + }) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_container.query_items( + query=query, + partition_key=['CA'], + max_item_count=7, + ) + + pager = query_iterable.by_page() + first_page = list(next(pager)) + assert first_page + token = pager.continuation_token + assert token + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resumed_remaining_ids = [] + for page in query_iterable.by_page(token): + resumed_remaining_ids.extend(item['id'] for item in page) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] for item 
in created_container.query_items(query=query, partition_key=['CA']) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + finally: + db.delete_container(created_container.id) + + + # ------------------------------------------------------------------ # + # Two-partition feed_range + # ------------------------------------------------------------------ # + def test_two_partition_feed_range(self): + """Construct a feed_range that overlaps two adjacent physical + partitions and pin three invariants: + + (a) per-page item count ≤ ``max_item_count`` (the fan-out must + not concatenate per-overlap responses into one oversized + logical page), + (b) no duplicate ids across pages (each overlap's outbound + continuation must be preserved so the next page resumes + instead of restarting from offset 0), + (c) the union of ids returned matches the union of ids from + independent per-partition scans (no missing items). + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + # Find the first adjacent pair where both partitions hold enough docs + # to drive ≥ 3 pages under PAGE_SIZE = 5. 
+ chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = _ids_via_per_partition_scan( + container, [chosen[0], chosen[1]]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + # (a) every page must respect max_item_count + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages (index, size): {oversized}.") + + # (b) no duplicates across pages + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} items returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + # (c) no missing items vs ground truth + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)} ids, " + f"ground_truth={len(ground_truth)} ids, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + def test_two_partition_feed_range_count_aggregate_pagination(self): + """Run a VALUE aggregate through a two-partition crossing feed_range. + + Guards aggregate-specific invariants on the multi-overlap path: + (a) each logical page still respects ``max_item_count``, + (b) partial aggregate fragments are merged client-side (one scalar + result after draining), + (c) merged count matches an independent per-partition scan baseline. 
+ """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Ground truth from independent per-partition scans, not aggregate path. + expected_count = len(_ids_via_per_partition_scan(container, [chosen[0], chosen[1]])) + + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + for page in pager: + items = list(page) + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_count, ( + "merged COUNT result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_count}") + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + + @pytest.mark.parametrize("merge_error_type", [TypeError, KeyError]) + def 
test_two_partition_feed_range_merge_fallback_preserves_rows( + self, monkeypatch, caplog, merge_error_type + ): + """Force merge failures and verify fallback extends docs without loss.""" + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = _ids_via_per_partition_scan(container, [chosen[0], chosen[1]]) + + merge_call_count = 0 + + def _raising_merge(*_args, **_kwargs): + nonlocal merge_call_count + merge_call_count += 1 + raise merge_error_type("injected-merge-failure") + + monkeypatch.setattr(_base, "_merge_query_results", _raising_merge) + + with caplog.at_level("WARNING", logger="azure.cosmos._cosmos_client_connection"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + assert merge_call_count > 0 + assert any( + "Falling back to non-aggregate merge after aggregate merge failure" in record.getMessage() + for record in caplog.records + ), "Expected warning log for merge fallback path" + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages (index, size): {oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"fallback path 
produced duplicates: {len(all_ids)} items returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + assert unique == ground_truth, ( + f"fallback path dropped/added items: returned={len(unique)} ids, " + f"ground_truth={len(ground_truth)} ids, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + def test_exception_during_post_preserves_resume_checkpoint(self): + """Inject a POST failure mid-query and verify the call site stamps + an outbound continuation that resumes from the last successful slice. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + client_conn = container.client_connection + original_post = client_conn._CosmosClientConnection__Post + call_count = 0 + + def _failing_post(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("injected-post-failure") + return original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + with pytest.raises(RuntimeError, match="injected-post-failure"): + _ = list(next(pager)) + finally: + client_conn._CosmosClientConnection__Post = original_post + + token = 
client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + assert token, "Expected continuation checkpoint to be stamped on POST failure" + decoded = _decode_token(token) + assert decoded is not None + assert decoded["c"][0]["min"] == chosen[1][0], ( + "Checkpoint should resume from the second sub-range after the first " + "sub-range completed successfully before failure." + ) + + def test_explode_iteration_guard_raises_in_query_loop(self, monkeypatch): + """Drive the live ``__QueryFeed`` explode loop until the runtime guard raises.""" + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + p0, p1 = partitions[0], partitions[1] + crossing = _crossing_feed_range(p0[0], p1[1]) + client_conn = container.client_connection + + # Force every routing lookup to look like an unresolved post-split overlap. + def _always_multi_overlap(_rid, feed_ranges, _opts): + head = feed_ranges[0] + return [ + {"id": "left", "minInclusive": head.min, "maxExclusive": head.max}, + {"id": "right", "minInclusive": head.min, "maxExclusive": head.max}, + ] + + monkeypatch.setattr( + client_conn._routing_map_provider, "get_overlapping_ranges", _always_multi_overlap + ) + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation._MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS", 2 + ) + + with pytest.raises(RuntimeError) as excinfo: + list( + container.query_items( + query="SELECT * FROM c", feed_range=crossing, max_item_count=PAGE_SIZE + ).by_page() + ) + assert "split re-resolution" in str(excinfo.value) + + def test_no_progress_guard_logs_warning_in_query_loop(self, monkeypatch, caplog): + """Drive repeated empty pages with unchanged continuation and assert warning emission.""" + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + 
p0, p1 = partitions[0], partitions[1] + crossing = _crossing_feed_range(p0[0], p1[1]) + client_conn = container.client_connection + + post_call_count = 0 + + def _stalled_post(*_args, **_kwargs): + nonlocal post_call_count + post_call_count += 1 + continuation = "stalled-token" if post_call_count <= 3 else None + return {"Documents": []}, {http_constants.HttpHeaders.Continuation: continuation} + + monkeypatch.setattr(client_conn, "_CosmosClientConnection__Post", _stalled_post) + monkeypatch.setattr( + "azure.cosmos._cosmos_client_connection._MAX_CONSECUTIVE_NO_PROGRESS_PAGES", 2 + ) + + with caplog.at_level("WARNING", logger="azure.cosmos._cosmos_client_connection"): + list( + container.query_items( + query="SELECT * FROM c", feed_range=crossing, max_item_count=PAGE_SIZE + ).by_page() + ) + + assert post_call_count >= 3 + assert any( + "same continuation token" in record.getMessage() for record in caplog.records + ), "Expected warning log from no-progress guard" + + # ------------------------------------------------------------------ # + # Three-way overlap (synthetic, wider fan-out) + # ------------------------------------------------------------------ # + def test_three_way_overlap(self): + """Same shape as ``test_two_partition_feed_range`` but with a + ``feed_range`` that overlaps **three** adjacent physical partitions. + Wider fan-out exercises the same three guarantees as the + two-partition test on a larger overlap set. 
+ """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 3: + pytest.skip("Need a container with ≥ 3 physical partitions") + + chosen: Optional[List[Tuple[str, str]]] = None + for i in range(len(partitions) - 2): + triple = partitions[i:i + 3] + if all(_count_in_range(container, mn, mx) >= MIN_DOCS_PER_PARTITION + for mn, mx in triple): + chosen = triple + break + if chosen is None: + pytest.skip("No three adjacent partitions all populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + assert chosen is not None # narrows for type checkers; pytest.skip raises + crossing = _crossing_feed_range(chosen[0][0], chosen[2][1]) + ground_truth = _ids_via_per_partition_scan(container, chosen) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated; sizes={[len(p) for p in pages]}; " + f"oversized={oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + # ------------------------------------------------------------------ # + # Post-split resume (slow; requires a real partition split) + # ------------------------------------------------------------------ # + @pytest.mark.cosmosSplit + def test_post_split_resume(self): + """End-to-end "the routing layout changed underneath a saved + continuation token" scenario: + + 1. 
Construct a 2-overlap crossing feed_range under the *current* + routing map; drain page 1 and save the continuation token. + 2. Trigger a real partition split (``trigger_split``) so the + container's physical partition count grows. The same EPK + ``{min, max}`` interval now overlaps a different (≥ 2) set of + physical partitions. + 3. Resume with the saved continuation token + the same + ``feed_range``. Drain remaining pages. + 4. Assert: combined ids across page 1 + post-split pages are + unique and equal the union of a fresh per-partition scan over + the same EPK interval. + + On the post-split resume path, the saved continuation must remain + valid (or be safely restarted) under the new physical layout - the + combined ids across the split boundary must still be unique and + cover the same EPK interval. + """ + container = _get_container() + partitions_before = _sorted_partition_ranges(container) + if len(partitions_before) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions_before) - 1): + p0, p1 = partitions_before[i], partitions_before[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Step 1 — drain page 1 only and save the outbound continuation. 
+ pager_pre = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + page_1 = list(next(pager_pre)) + ct_after_page_1 = pager_pre.continuation_token + page_1_ids = [item["id"] for item in page_1] + assert ct_after_page_1, ( + "expected a non-empty continuation token after page 1; the " + "feed_range overlaps two partitions and the first page should " + "not have drained the whole interval") + + # Step 2 — trigger a real split. This is the slow step (up to 10 min + # for the offer-replace operation to complete). + target_throughput = max(REPRO_THROUGHPUT * 2, 60000) + try: + test_config.TestConfig.trigger_split(container, target_throughput) + except unittest.SkipTest: + raise + # Allow the routing map a brief settling period after the split + # completes, then force a refresh so the SDK sees the new layout. + time.sleep(10) + list(container.read_feed_ranges(force_refresh=True)) + + # Step 3 — resume with the saved continuation, same feed_range. + pager_post = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=ct_after_page_1) + post_pages, post_ids = _drain_pages(pager_post) + + combined_ids = page_1_ids + post_ids + unique = set(combined_ids) + duplicate_count = len(combined_ids) - len(unique) + # When the parent's backend continuation is dropped during the + # post-split explode (children won't accept the parent's bc), + # the children restart at offset 0 of their slice. The lower + # child can therefore re-emit up to PAGE_SIZE rows that page 1 + # already returned. The strict no-dup invariant only holds when + # page 1 happened to fully drain the parent slice; the + # bounded-replay invariant always holds and is what we assert + # here. The strict no-loss / no-out-of-range guarantee is still + # enforced by the ``unique == ground_truth`` check below. 
+ assert duplicate_count <= PAGE_SIZE, ( + f"unexpected duplicate count across the split boundary: " + f"{len(combined_ids)} ids returned across page 1 + " + f"{len(post_pages)} post-split page(s), {len(unique)} distinct, " + f"{duplicate_count} duplicate(s) (max allowed: {PAGE_SIZE}).") + + oversized = [(i, len(p)) for i, p in enumerate(post_pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated post-split; sizes=" + f"{[len(p) for p in post_pages]}; oversized={oversized}.") + + # Re-derive ground truth against the post-split routing map. + partitions_after = _sorted_partition_ranges(container) + post_split_overlaps = [(mn, mx) for (mn, mx) in partitions_after + if min(p1_max, mx) > max(p0_min, mn)] + ground_truth = _ids_via_per_partition_scan(container, post_split_overlaps) + assert unique == ground_truth, ( + f"item-set mismatch after post-split resume: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + # ------------------------------------------------------------------ # + # Legacy opaque token compatibility + # ------------------------------------------------------------------ # + def test_legacy_opaque_token_compat(self, caplog): + """Use an opaque continuation token (not base64 JSON, not v=1). + The query restarts from the beginning. + + Asserts: + (a) no exception is raised on the resume call, + (b) all batches restart from the beginning (the union of ids + returned equals the per-partition ground truth), + (c) every page respects ``max_item_count``, + (d) pagination runs to completion (final continuation is None). 
+ """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = _ids_via_per_partition_scan(container, [chosen[0], chosen[1]]) + + # Opaque continuation string that does not match the structured + # token format. + # cspell:ignore AOXB BAAAAAAAAAA EAAAAFAAAA + legacy_token = "+RID:~Yxs1AOXBSp4BAAAAAAAAAA==#RT:1#TRC:5#ISV:2#IEO:65567#FPC:AgEAAAAFAAAA" + + with caplog.at_level("WARNING", logger="azure.cosmos._cosmos_client_connection"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=legacy_token) + + # (a) no exception on iteration + pages, all_ids = _drain_pages(pager) + + assert any( + "not in the supported structured format" in record.getMessage() + for record in caplog.records + ), "Expected warning log when a non-structured continuation token is supplied" + assert all( + legacy_token not in record.getMessage() for record in caplog.records + ), "Warning log must not include raw continuation token text" + + # (c) page-size limit respected + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated under legacy-token resume; " + f"sizes={[len(p) for p in pages]}; oversized={oversized}") + + # (b) full restart from offset 0 -> coverage matches ground truth, + # no duplicates + unique = 
set(all_ids) + assert len(all_ids) == len(unique), ( + f"legacy-token restart produced duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"legacy-token restart coverage mismatch: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + # (d) pagination ran to completion + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages on " + f"legacy-token restart; got {pager.continuation_token!r}") + + # ------------------------------------------------------------------ # + # Identity-fingerprint mismatch rejection (live half) + # ------------------------------------------------------------------ # + def test_token_identity_mismatch_rejected(self): + """Round-trip a token through ``query_items`` then replay it + against (a) a different query text, (b) a different parameter + value, and (c) a different ``feed_range``. Each replay must raise + ``ValueError`` from the call-site validation in ``__QueryFeed`` + with a message that names the failing field. + + The unit tests in ``test_feed_range_continuation_token.py`` cover + the hash-inequality contract; this test covers the live raise + path through the SDK's actual query pipeline. 
+ """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Step 1 — drain page 1 of a parameterized query and save the + # outbound continuation. The token's qh/frh fingerprints encode + # this query + this feed_range. + # Use bracket notation: ``value`` is a reserved word in Cosmos SQL. + original_query = "SELECT * FROM c WHERE c[\"value\"] >= @v" + original_params = [{"name": "@v", "value": 0}] + pager = container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + _ = list(next(pager)) + token = pager.continuation_token + assert token, ("expected a non-empty continuation after page 1; " + "the test cannot exercise resume validation otherwise") + + # (a) Different query TEXT — qh mismatch. + with pytest.raises(ValueError) as excinfo_q: + list(container.query_items( + query={"query": "SELECT * FROM c WHERE c[\"value\"] >= @v AND c.id != ''", + "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + msg_q = str(excinfo_q.value).lower() + assert "query" in msg_q, ( + "ValueError on query-text mismatch must name the failing " + f"field; got: {excinfo_q.value!r}") + + # (b) Different parameter VALUE — same query text, different qh. 
+ with pytest.raises(ValueError) as excinfo_p: + list(container.query_items( + query={"query": original_query, + "parameters": [{"name": "@v", "value": 999999}]}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "query" in str(excinfo_p.value).lower(), ( + "ValueError on parameter-value mismatch must name the query " + f"field; got: {excinfo_p.value!r}") + + # (c) Different feed_range — frh mismatch. Use a single-partition + # sub-range of the original crossing range (still inside the same + # collection so cr matches; only frh differs). + single_p0 = _crossing_feed_range(chosen[0][0], chosen[0][1]) + with pytest.raises(ValueError) as excinfo_f: + list(container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=single_p0, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "feed_range" in str(excinfo_f.value).lower(), ( + "ValueError on feed_range mismatch must name the feed_range " + f"field; got: {excinfo_f.value!r}") diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py new file mode 100644 index 000000000000..83ffc44fe06c --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py @@ -0,0 +1,811 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+"""Async multi-partition feed_range tests.""" + +import asyncio +import unittest +import uuid +from typing import Iterable, List, Optional, Tuple + +import pytest +import pytest_asyncio + +import test_config +from azure.cosmos import _base +from azure.cosmos import http_constants +from azure.cosmos.aio import CosmosClient +from azure.cosmos._routing.feed_range_continuation import _decode_token +from azure.cosmos.partition_key import PartitionKey + +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID + +REPRO_CONTAINER_ID = "FeedRangeMultiPartitionAsync-" + str(uuid.uuid4()) +REPRO_PARTITION_KEY = "pk" +REPRO_THROUGHPUT = CONFIG.THROUGHPUT_FOR_5_PARTITIONS +REPRO_DOC_COUNT = 200 +PAGE_SIZE = 5 +MIN_DOCS_PER_PARTITION = 15 + + +def _client() -> CosmosClient: + return CosmosClient(HOST, KEY) + + +def _get_container(client: CosmosClient): + return client.get_database_client(DATABASE_ID).get_container_client(REPRO_CONTAINER_ID) + + +async def _sorted_partition_ranges(container) -> List[Tuple[str, str]]: + pairs: List[Tuple[str, str]] = [] + async for fr in container.read_feed_ranges(): + r = fr["Range"] + pairs.append((r["min"], r["max"])) + pairs.sort(key=lambda p: p[0]) + return pairs + + +async def _count_in_range(container, range_min: str, range_max: str) -> int: + fr = test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + items = [it async for it in container.query_items( + query="SELECT VALUE COUNT(1) FROM c", feed_range=fr)] + return items[0] if items else 0 + + +def _crossing_feed_range(range_min: str, range_max: str): + return test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + + +async def _ids_via_per_partition_scan(container, partition_ranges: Iterable[Tuple[str, str]]): + 
ground_truth = set() + for (mn, mx) in partition_ranges: + fr = _crossing_feed_range(mn, mx) + async for item in container.query_items(query="SELECT c.id FROM c", feed_range=fr): + ground_truth.add(item["id"]) + return ground_truth + + +async def _values_via_per_partition_scan(container, partition_ranges: Iterable[Tuple[str, str]]): + values = [] + for (mn, mx) in partition_ranges: + fr = _crossing_feed_range(mn, mx) + async for value in container.query_items(query='SELECT VALUE c["value"] FROM c', feed_range=fr): + values.append(value) + return values + + +async def _drain_pages(pager) -> Tuple[List[List[dict]], List[str]]: + pages: List[List[dict]] = [] + all_ids: List[str] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + all_ids.extend(it["id"] for it in items) + return pages, all_ids + + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown_async(): + client = _client() + db = client.get_database_client(DATABASE_ID) + container = await db.create_container_if_not_exists( + id=REPRO_CONTAINER_ID, + partition_key=PartitionKey(path="/" + REPRO_PARTITION_KEY, kind="Hash"), + offer_throughput=REPRO_THROUGHPUT) + for i in range(REPRO_DOC_COUNT): + await container.upsert_item({ + REPRO_PARTITION_KEY: f"pk-{i:04d}", + "id": f"doc-{i:04d}", + "value": i, + }) + yield + try: + await db.delete_container(REPRO_CONTAINER_ID) + except Exception: # pylint: disable=broad-except + pass + await client.close() + + +@pytest.mark.cosmosQuery +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown_async") +class TestFeedRangeMultiPartitionAsync: + """Async end-to-end tests for feed_range queries that overlap multiple + physical partitions.""" + + # ------------------------------------------------------------------ # + # Single-partition control + # ------------------------------------------------------------------ # + async def test_single_partition_feed_range_async(self): + """Single-partition 
regression guard.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if not partitions: + pytest.skip("Container has no physical partitions") + + chosen_pp = None + for (mn, mx) in partitions: + if await _count_in_range(container, mn, mx) >= MIN_DOCS_PER_PARTITION: + chosen_pp = (mn, mx) + break + if chosen_pp is None: + pytest.skip("No single partition populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + single = _crossing_feed_range(chosen_pp[0], chosen_pp[1]) + ground_truth = await _ids_via_per_partition_scan(container, [chosen_pp]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=single, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + for idx, page in enumerate(pages): + if idx < len(pages) - 1: + assert len(page) == PAGE_SIZE, ( + f"page {idx} returned {len(page)} items, expected " + f"exactly {PAGE_SIZE}") + else: + assert len(page) <= PAGE_SIZE + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"single-partition path returned duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"single-partition coverage mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + + # ------------------------------------------------------------------ # + # Two-partition feed_range + # ------------------------------------------------------------------ # + async def test_two_partition_feed_range_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + 
pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = await _ids_via_per_partition_scan( + container, [chosen[0], chosen[1]]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages: {oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + async def test_two_partition_feed_range_count_aggregate_pagination_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= 
MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + expected_count = len(await _ids_via_per_partition_scan(container, [chosen[0], chosen[1]])) + + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_count, ( + "merged COUNT result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_count}") + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + @pytest.mark.parametrize("merge_error_type", [TypeError, KeyError]) + async def test_two_partition_feed_range_merge_fallback_preserves_rows_async( + self, monkeypatch, caplog, merge_error_type + ): + """Force merge failures and verify fallback extends docs without loss.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 
physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with >= " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = await _ids_via_per_partition_scan(container, [chosen[0], chosen[1]]) + + merge_call_count = 0 + + def _raising_merge(*_args, **_kwargs): + nonlocal merge_call_count + merge_call_count += 1 + raise merge_error_type("injected-merge-failure") + + monkeypatch.setattr(_base, "_merge_query_results", _raising_merge) + + with caplog.at_level("WARNING", logger="azure.cosmos.aio._cosmos_client_connection_async"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + assert merge_call_count > 0 + assert any( + "Falling back to non-aggregate merge after aggregate merge failure" in record.getMessage() + for record in caplog.records + ), "Expected warning log for merge fallback path" + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages (index, size): {oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"fallback path produced duplicates: {len(all_ids)} items returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + assert unique == ground_truth, ( + f"fallback path dropped/added items: returned={len(unique)} ids, " + f"ground_truth={len(ground_truth)} ids, " + 
f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + async def test_two_partition_feed_range_min_aggregate_pagination_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + expected_values = await _values_via_per_partition_scan(container, [chosen[0], chosen[1]]) + expected_min = min(expected_values) + + pager = container.query_items( + query='SELECT VALUE MIN(c["value"]) FROM c', + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_min, ( + "merged MIN result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_min}") + assert pager.continuation_token in (None, "", b""), ( + 
f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + async def test_two_partition_feed_range_max_aggregate_pagination_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + expected_values = await _values_via_per_partition_scan(container, [chosen[0], chosen[1]]) + expected_max = max(expected_values) + + pager = container.query_items( + query='SELECT VALUE MAX(c["value"]) FROM c', + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_max, ( + "merged MAX result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_max}") + assert pager.continuation_token in (None, 
"", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + async def test_exception_during_post_preserves_resume_checkpoint_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + client_conn = container.client_connection + original_post = client_conn._CosmosClientConnection__Post + call_count = 0 + + async def _failing_post(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("injected-post-failure") + return await original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + with pytest.raises(RuntimeError, match="injected-post-failure"): + page_iter = await pager.__anext__() + _ = [it async for it in page_iter] + finally: + client_conn._CosmosClientConnection__Post = original_post + + token = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + assert token, "Expected continuation checkpoint to be stamped on POST failure" + decoded = _decode_token(token) + assert decoded is not None + assert decoded["c"][0]["min"] == chosen[1][0], ( + "Checkpoint should resume from 
the second sub-range after the first " + "sub-range completed successfully before failure." + ) + finally: + await client.close() + + # ------------------------------------------------------------------ # + # Three-way overlap + # ------------------------------------------------------------------ # + async def test_three_way_overlap_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 3: + pytest.skip("Need a container with ≥ 3 physical partitions") + + chosen: Optional[List[Tuple[str, str]]] = None + for i in range(len(partitions) - 2): + triple = partitions[i:i + 3] + ok = True + for mn, mx in triple: + if await _count_in_range(container, mn, mx) < MIN_DOCS_PER_PARTITION: + ok = False + break + if ok: + chosen = triple + break + if chosen is None: + pytest.skip("No three adjacent partitions all populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + assert chosen is not None # narrows for type checkers; pytest.skip raises + crossing = _crossing_feed_range(chosen[0][0], chosen[2][1]) + ground_truth = await _ids_via_per_partition_scan(container, chosen) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated; sizes={[len(p) for p in pages]}; " + f"oversized={oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await 
client.close() + + # ------------------------------------------------------------------ # + # Post-split resume (slow) + # ------------------------------------------------------------------ # + @pytest.mark.cosmosSplit + async def test_post_split_resume_async(self): + client = _client() + try: + container = _get_container(client) + partitions_before = await _sorted_partition_ranges(container) + if len(partitions_before) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions_before) - 1): + p0, p1 = partitions_before[i], partitions_before[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Step 1 — drain page 1 only and save the outbound continuation. + pager_pre = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + page_1_iter = await pager_pre.__anext__() + page_1 = [it async for it in page_1_iter] + ct_after_page_1 = pager_pre.continuation_token + page_1_ids = [item["id"] for item in page_1] + assert ct_after_page_1, ( + "expected a non-empty continuation token after page 1") + + # Step 2 — trigger a real split. + target_throughput = max(REPRO_THROUGHPUT * 2, 60000) + try: + await test_config.TestConfig.trigger_split_async(container, target_throughput) + except unittest.SkipTest: + raise + await asyncio.sleep(10) + _ = [fr async for fr in container.read_feed_ranges(force_refresh=True)] + + # Step 3 — resume with the saved continuation, same feed_range. 
+ pager_post = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=ct_after_page_1) + post_pages, post_ids = await _drain_pages(pager_post) + + combined_ids = page_1_ids + post_ids + unique = set(combined_ids) + duplicate_count = len(combined_ids) - len(unique) + # When the parent's backend continuation is dropped during + # the post-split explode (children won't accept the + # parent's bc), the children restart at offset 0 of their + # slice. The lower child can therefore re-emit up to + # PAGE_SIZE rows that page 1 already returned. The strict + # no-dup invariant only holds when page 1 happened to fully + # drain the parent slice; the bounded-replay invariant + # always holds and is what we assert here. The strict + # no-loss / no-out-of-range guarantee is still enforced by + # the ``unique == ground_truth`` check below. + assert duplicate_count <= PAGE_SIZE, ( + f"unexpected duplicate count across the split boundary: " + f"{len(combined_ids)} ids returned across page 1 + " + f"{len(post_pages)} post-split page(s), {len(unique)} distinct, " + f"{duplicate_count} duplicate(s) (max allowed: {PAGE_SIZE}).") + + oversized = [(i, len(p)) for i, p in enumerate(post_pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated post-split; sizes=" + f"{[len(p) for p in post_pages]}; oversized={oversized}.") + + partitions_after = await _sorted_partition_ranges(container) + post_split_overlaps = [(mn, mx) for (mn, mx) in partitions_after + if min(p1_max, mx) > max(p0_min, mn)] + ground_truth = await _ids_via_per_partition_scan(container, post_split_overlaps) + assert unique == ground_truth, ( + f"item-set mismatch after post-split resume: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + # 
------------------------------------------------------------------ # + # Legacy opaque token compatibility + # ------------------------------------------------------------------ # + async def test_legacy_opaque_token_compat_async(self, caplog): + """Use an opaque continuation token and verify restart behavior.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = await _ids_via_per_partition_scan( + container, [chosen[0], chosen[1]]) + + # cspell:ignore AOXB BAAAAAAAAAA EAAAAFAAAA + legacy_token = "+RID:~Yxs1AOXBSp4BAAAAAAAAAA==#RT:1#TRC:5#ISV:2#IEO:65567#FPC:AgEAAAAFAAAA" + + with caplog.at_level("WARNING", logger="azure.cosmos.aio._cosmos_client_connection_async"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=legacy_token) + pages, all_ids = await _drain_pages(pager) + + assert any( + "not in the supported structured format" in record.getMessage() + for record in caplog.records + ), "Expected warning log when a non-structured continuation token is supplied" + assert all( + legacy_token not in record.getMessage() for record in caplog.records + ), "Warning log must not include raw continuation token text" + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, 
( + f"page-size limit violated under legacy-token resume; " + f"sizes={[len(p) for p in pages]}; oversized={oversized}") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"legacy-token restart produced duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"legacy-token restart coverage mismatch: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages on " + f"legacy-token restart; got {pager.continuation_token!r}") + finally: + await client.close() + + # ------------------------------------------------------------------ # + # Identity-fingerprint mismatch rejection (live half) + # ------------------------------------------------------------------ # + async def test_token_identity_mismatch_rejected_async(self): + """Live identity-mismatch rejection test.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Use bracket notation: ``value`` is a reserved word in Cosmos SQL. 
+ original_query = "SELECT * FROM c WHERE c[\"value\"] >= @v" + original_params = [{"name": "@v", "value": 0}] + pager = container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + page_1_iter = await pager.__anext__() + _ = [it async for it in page_1_iter] + token = pager.continuation_token + assert token, ("expected a non-empty continuation after page 1; " + "test cannot exercise resume validation otherwise") + + async def _drain(p): + async for page in p: + _ = [it async for it in page] + + # (a) Different query TEXT. + with pytest.raises(ValueError) as excinfo_q: + await _drain(container.query_items( + query={"query": "SELECT * FROM c WHERE c[\"value\"] >= @v AND c.id != ''", + "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "query" in str(excinfo_q.value).lower(), ( + f"ValueError on query-text mismatch must name the failing " + f"field; got: {excinfo_q.value!r}") + + # (b) Different parameter VALUE. + with pytest.raises(ValueError) as excinfo_p: + await _drain(container.query_items( + query={"query": original_query, + "parameters": [{"name": "@v", "value": 999999}]}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "query" in str(excinfo_p.value).lower(), ( + f"ValueError on parameter-value mismatch must name the " + f"query field; got: {excinfo_p.value!r}") + + # (c) Different feed_range. 
+ single_p0 = _crossing_feed_range(chosen[0][0], chosen[0][1]) + with pytest.raises(ValueError) as excinfo_f: + await _drain(container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=single_p0, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "feed_range" in str(excinfo_f.value).lower(), ( + f"ValueError on feed_range mismatch must name the " + f"feed_range field; got: {excinfo_f.value!r}") + finally: + await client.close() + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py index 0a4e19049886..0a3555a908cf 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. import unittest import uuid +from unittest.mock import patch import pytest from azure.core.exceptions import ServiceRequestError, ServiceResponseError @@ -38,94 +39,68 @@ def setUpClass(cls): cls.created_database = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.created_container = cls.created_database.get_container_client(cls.TEST_CONTAINER_ID) + def _setup_read_regions(self, location_cache, regions): + """Set all read region attributes consistently so update_location_cache() recalculates correctly.""" + location_cache.account_read_locations = regions + location_cache.account_read_regional_routing_contexts_by_location = { + r: self.REGIONAL_ENDPOINT for r in regions} + location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + location_cache.effective_preferred_locations = regions + + def _setup_write_regions(self, location_cache, regions): + """Set all write region attributes consistently so update_location_cache() recalculates correctly.""" + location_cache.account_write_locations = regions + 
location_cache.account_write_regional_routing_contexts_by_location = { + r: self.REGIONAL_ENDPOINT for r in regions} + location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + def test_service_request_retry_policy(self): mock_client = CosmosClient(self.host, self.masterKey) db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) created_item = container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility.ExecuteFunction - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Change the location cache to have 3 preferred read regions original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function - # Now we test with a query operation, iterating through items 
sends request without request object - # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt - # equal to the available read region locations + # Test read with IgnoreQuery mock (allows query/pkranges requests through) + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(_retry_utility.ExecuteFunction) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + container.read_item(created_item['id'], created_item['pk']) + assert mf.counter == expected_counter - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location - original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + # Now we test with a query operation + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() - _retry_utility.ExecuteFunction = mf - items = list(container.query_items(query="SELECT * FROM c", partition_key=created_item['pk'])) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + list(container.query_items(query="SELECT * FROM c", partition_key=created_item['pk'])) assert mf.counter 
== expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + container.read_item(created_item['id'], created_item['pk']) assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) expected_counter = len(original_location_cache.write_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility.ExecuteFunction = mf - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - # Should retry twice in each region + mf = self.MockExecuteServiceRequestException() + with 
patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function def test_service_response_retry_policy(self): mock_client = CosmosClient(self.host, self.masterKey) @@ -133,71 +108,41 @@ def test_service_response_retry_policy(self): container = db.get_container_client(self.TEST_CONTAINER_ID) created_item = container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility.ExecuteFunction - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Change the location cache to have 3 preferred read regions original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] - try: - # Mock the function to return the ServiceResponseException we retry - mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(Exception, self.original_execute_function) - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(Exception, _retry_utility.ExecuteFunction) + with patch.object(_retry_utility, 'ExecuteFunction', 
mf): + with pytest.raises(ServiceResponseError): + container.read_item(created_item['id'], created_item['pk']) assert mf.counter == 3 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(Exception) - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_read_regions(original_location_cache, [self.REGION1]) + mf = self.MockExecuteServiceResponseException(Exception) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + container.read_item(created_item['id'], created_item['pk']) assert mf.counter == 1 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(Exception) - _retry_utility.ExecuteFunction = mf - # Even though we have 2 preferred write endpoints, - # we will only run the exception once due to no retries on write requests - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_write_regions(original_location_cache, [self.REGION1, 
self.REGION2]) + mf = self.MockExecuteServiceResponseException(Exception) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + # Even though we have 2 preferred write endpoints, + # we will only run the exception once due to no retries on write requests + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request with retry write enabled - which should retry once - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(Exception) - _retry_utility.ExecuteFunction = mf - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(Exception) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) assert mf.counter == 2 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function def test_service_request_connection_retry_policy(self): # Mock the client retry policy to see the same-region retries that happen there @@ -251,38 +196,32 @@ def test_global_endpoint_manager_retry(self): # - GetDatabaseAccountStub allows us to receive any number of endpoints for that call independent of account used exception = ServiceRequestError("mock exception") exception.exc_type = Exception - self.original_get_database_account_stub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub connection_retry_policy = test_config.MockConnectionRetryPolicy(resource_type="docs", error=exception) - mock_client = 
CosmosClient(self.host, self.masterKey, connection_retry_policy=connection_retry_policy, - preferred_locations=[self.REGION1, self.REGION2]) - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub + with patch.object(_global_endpoint_manager._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + mock_client = CosmosClient(self.host, self.masterKey, connection_retry_policy=connection_retry_policy, + preferred_locations=[self.REGION1, self.REGION2]) + db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) - try: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + with patch.object(_global_endpoint_manager._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + with pytest.raises(ServiceRequestError): + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert connection_retry_policy.counter == 3 # 4 total in region retries assert len(connection_retry_policy.request_endpoints) == 4 - finally: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub # Now we try with a read request - reset the policy to reset the counter connection_retry_policy.request_endpoints = [] - try: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - container.read_item("some_id", "some_pk") - pytest.fail("Exception was not raised.") - except ServiceRequestError: + with patch.object(_global_endpoint_manager._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + with pytest.raises(ServiceRequestError): + 
container.read_item("some_id", "some_pk") assert connection_retry_policy.counter == 3 # 4 total requests in each main region (preferred read region 1 -> preferred read region 2) assert len(connection_retry_policy.request_endpoints) == 8 - finally: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub class MockExecuteServiceRequestException(object): def __init__(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py index eb3082fdc976..92419c71fb34 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py @@ -3,6 +3,7 @@ import unittest import uuid +from unittest.mock import patch import pytest from aiohttp.client_exceptions import (ClientConnectionError, ClientConnectionResetError, @@ -10,7 +11,7 @@ from azure.core.exceptions import ServiceRequestError, ServiceResponseError import test_config -from azure.cosmos import DatabaseAccount, _location_cache +from azure.cosmos import DatabaseAccount from azure.cosmos._location_cache import RegionalRoutingContext from azure.cosmos._request_object import RequestObject from azure.cosmos.aio import CosmosClient, _retry_utility_async, _global_endpoint_manager_async @@ -47,6 +48,33 @@ async def asyncTearDown(self): self.connectionPolicy.ConnectionRetryConfiguration = None await self.client.close() + async def _cancel_background_refresh_task(self, client): + """Cancel the GEM's background health check task to prevent it from + overwriting test cache mutations via update_location_cache().""" + gem = client.client_connection._global_endpoint_manager + if gem.refresh_task and not gem.refresh_task.done(): + gem.refresh_task.cancel() + try: + await gem.refresh_task + except BaseException: + pass + gem.refresh_task = None + + def _setup_read_regions(self, location_cache, regions): 
+ """Set all read region attributes consistently so update_location_cache() recalculates correctly.""" + location_cache.account_read_locations = regions + location_cache.account_read_regional_routing_contexts_by_location = { + r: self.REGIONAL_ENDPOINT for r in regions} + location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + location_cache.effective_preferred_locations = regions + + def _setup_write_regions(self, location_cache, regions): + """Set all write region attributes consistently so update_location_cache() recalculates correctly.""" + location_cache.account_write_locations = regions + location_cache.account_write_regional_routing_contexts_by_location = { + r: self.REGIONAL_ENDPOINT for r in regions} + location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + async def test_service_request_retry_policy_async(self): # ServiceRequestErrors will always retry, and will retry once per preferred region async with CosmosClient(self.host, self.masterKey) as mock_client: @@ -54,89 +82,49 @@ async def test_service_request_retry_policy_async(self): container = db.get_container_client(self.TEST_CONTAINER_ID) created_item = await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Cancel background health check to prevent it from overwriting cache mutations + await self._cancel_background_refresh_task(mock_client) + + # Change the location cache to have 3 preferred read regions original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: 
self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - # Now we test with a query operation, iterating through items sends request without request object - # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt - # equal to the available read region locations + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(_retry_utility_async.ExecuteFunctionAsync) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): + await container.read_item(created_item['id'], created_item['pk']) + assert mf.counter == expected_counter - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location - original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - 
original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + # Now we test with a query operation + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf - items = [item async for item in container.query_items(query="SELECT * FROM c", - partition_key=created_item['pk'])] - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): + items = [item async for item in container.query_items(query="SELECT * FROM c", + partition_key=created_item['pk'])] assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): + await container.read_item(created_item['id'], 
created_item['pk']) assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) expected_counter = len(original_location_cache.write_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function async def test_service_response_retry_policy_async(self): # For ServiceResponseErrors, we only do cross region retries on read requests or on ClientConnectionErrors @@ -146,92 +134,54 @@ async def test_service_response_retry_policy_async(self): container = db.get_container_client(self.TEST_CONTAINER_ID) created_item = await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Cancel 
background health check to prevent it from overwriting cache mutations + await self._cancel_background_refresh_task(mock_client) + + # Change the location cache to have 3 preferred read regions original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] - try: - # Mock the function to return the ClientConnectionError we retry - mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, - None, self.original_execute_function) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, + None, _retry_utility_async.ExecuteFunctionAsync) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.read_item(created_item['id'], created_item['pk']) assert mf.counter == 3 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync 
= mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_read_regions(original_location_cache, [self.REGION1]) + mf = self.MockExecuteServiceResponseException(AttributeError, None) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.read_item(created_item['id'], created_item['pk']) assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf - # Even though we have 2 preferred write endpoints, - # we will only run the exception once due to no retries on write requests - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + mf = self.MockExecuteServiceResponseException(AttributeError, None) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + # Even though we have 2 preferred write endpoints, + # we will only run the exception once due to no retries on write requests + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = 
self.original_execute_function # If we do a write request with a ClientConnectionError, # we will do cross-region retries like with read requests - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we try it out with a write request with retry write enabled - which should retry once - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(AttributeError, None) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": 
str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function async def test_service_request_connection_retry_policy_async(self): # Mock the client retry policy to see the same-region retries that happen there @@ -285,125 +235,91 @@ async def test_service_response_connection_retry_policy_async(self): async def test_service_response_errors_async(self): # Test for errors that are subclasses of ClientConnectionError for write requests - # Save the original ExecuteAsyncFunction function - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync async with CosmosClient(self.host, self.masterKey) as mock_client: db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) await container.read() + # Cancel background health check to prevent it from overwriting cache mutations + await self._cancel_background_refresh_task(mock_client) + original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + # For writes, set only the derived state directly since test_service_response_errors + # relies on mark_endpoint_unavailable -> update_location_cache() reducing write regions original_location_cache.account_write_locations = [self.REGION1, self.REGION2] original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, 
self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} - try: - # Start with a normal ServiceResponseException with no special casing - mf = self.MockExecuteServiceResponseException(AttributeError, AttributeError()) - _retry_utility_async.ExecuteFunctionAsync = mf - # Even though we have 2 preferred write endpoints, - # we will only run the exception once due to no retries on write requests - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + + # Start with a normal ServiceResponseException with no special casing + mf = self.MockExecuteServiceResponseException(AttributeError, AttributeError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + # Even though we have 2 preferred write endpoints, + # we will only run the exception once due to no retries on write requests + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we test the base ClientConnectionError to see in-region retry - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert 
mf.counter == 2 assert len(original_location_cache.location_unavailability_info_by_endpoint) == 1 host_unavailable = original_location_cache.location_unavailability_info_by_endpoint.get(self.host) assert host_unavailable is not None assert len(host_unavailable.get('operationType')) == 1 assert 'Write' in host_unavailable.get('operationType') - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function await container.read() # We send another request to the same error - since we marked one of the two in-region endpoints unavailable we don't retry - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 host_unavailable = original_location_cache.location_unavailability_info_by_endpoint.get(self.host) assert host_unavailable is not None assert len(host_unavailable.get('operationType')) == 1 assert 'Write' in host_unavailable.get('operationType') - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Reset the location cache's unavailable endpoints in order to try the same with other exceptions original_location_cache.location_unavailability_info_by_endpoint = {} original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] # Now we test ClientConnectionResetError, the subclass of ClientConnectionError - try: - # Reset the function to reset the counter - mf = 
self.MockExecuteServiceResponseException(ClientConnectionResetError, ClientConnectionResetError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientConnectionResetError, ClientConnectionResetError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 assert len(original_location_cache.location_unavailability_info_by_endpoint) == 1 host_unavailable = original_location_cache.location_unavailability_info_by_endpoint.get(self.host) assert host_unavailable is not None assert len(host_unavailable.get('operationType')) == 1 assert 'Write' in host_unavailable.get('operationType') - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Reset the location cache's unavailable endpoints in order to try the same with other exceptions original_location_cache.location_unavailability_info_by_endpoint = {} original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] # Now we test ServerConnectionError, the subclass of ClientConnectionError - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ServerConnectionError, ServerConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ServerConnectionError, ServerConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": 
str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Reset the location cache's unavailable endpoints in order to try the same with other exceptions original_location_cache.location_unavailability_info_by_endpoint = {} original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] # Now we test ClientOSError, the subclass of ClientConnectionError - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientOSError, ClientOSError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientOSError, ClientOSError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function async def test_global_endpoint_manager_retry_async(self): # For this test we mock both the ConnectionRetryPolicy and the GetDatabaseAccountStub @@ -411,40 +327,29 @@ async def test_global_endpoint_manager_retry_async(self): # - GetDatabaseAccountStub allows us to receive any number of endpoints for that call independent of account used exception = ServiceRequestError("mock exception") exception.exc_type = Exception - self.original_get_database_account_stub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub connection_policy = self.connectionPolicy connection_retry_policy = 
test_config.MockConnectionRetryPolicyAsync(resource_type="docs", error=exception) connection_policy.ConnectionRetryConfiguration = connection_retry_policy - async with CosmosClient(self.host, self.masterKey, connection_policy=connection_policy, - preferred_locations=[self.REGION1, self.REGION2]) as mock_client: - db = mock_client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_ID) - - try: - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + with patch.object(_global_endpoint_manager_async._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + async with CosmosClient(self.host, self.masterKey, connection_policy=connection_policy, + preferred_locations=[self.REGION1, self.REGION2]) as mock_client: + db = mock_client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_ID) + + with pytest.raises((ServiceRequestError, CosmosHttpResponseError)): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert connection_retry_policy.counter == 3 # 4 total in region retries assert len(connection_retry_policy.request_endpoints) == 4 - except CosmosHttpResponseError as e: - print(e) - finally: - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub - - # Now we try with a read request - reset the policy to reset the counter - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - connection_retry_policy.request_endpoints = [] - try: - await container.read_item("some_id", "some_pk") - pytest.fail("Exception was not raised.") - except ServiceRequestError: + + # Now we try with a read request - reset the policy to reset the counter + connection_retry_policy.request_endpoints = [] + with 
pytest.raises(ServiceRequestError): + await container.read_item("some_id", "some_pk") assert connection_retry_policy.counter == 3 # 4 total requests in each main region (preferred read region 1 -> preferred read region 2) assert len(connection_retry_policy.request_endpoints) == 8 - finally: - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub class MockExecuteServiceRequestException(object): def __init__(self):