From 5266479c837c3fa3ed26a038afdccd0cc5a0d497 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 17 Feb 2026 12:16:09 -0600 Subject: [PATCH 01/18] fix - fixing global endpoint issue (#45200) * fix - fixing global endpoint issue * fix - updating CHANGELOG (cherry picked from commit 9eff742da360953693d742002ed1636d7e5e5abe) --- .../azure/cosmos/_location_cache.py | 7 +- .../azure-cosmos/tests/test_location_cache.py | 83 +++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index cf8239488712..dd45279b6f59 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -513,8 +513,11 @@ def get_preferred_regional_routing_contexts( else: regional_endpoints.append(regional_endpoint) - # If all preferred locations are unavailable, honor the preferred list by trying them anyway. - if not regional_endpoints and unavailable_endpoints: + # Always append unavailable endpoints to the end of the list so they can be + # used as a last resort. This ensures that when all healthy endpoints are filtered + # out (e.g., by excluded_locations), the SDK can still fall back to unavailable + # regional endpoints rather than the global endpoint. + if unavailable_endpoints: regional_endpoints.extend(unavailable_endpoints) # If there are no preferred locations or none of the preferred locations are in the account, diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index 090717519222..84dce71d567d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -472,5 +472,88 @@ def test_write_fallback_to_global_after_regional_retries_exhausted(self): final_endpoint = lc.resolve_service_endpoint(write_request) assert final_endpoint == location1_endpoint + def test_unavailable_endpoints_not_dropped_from_routing_list(self): + """ + Unavailable endpoints should be appended to the end of the routing list, + not dropped entirely. + + Scenario: + - Customer has preferred_locations = ["East US", "West US 2"] + - East US is marked unavailable for writes + - Customer makes a request with excluded_locations = ["West US 2"] + - Expected: East US should still be available as fallback (unavailable but in the list) + """ + # Setup: Two preferred locations, multi-write enabled + preferred_locations = [location1_name, location2_name] + lc = refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + db_acc = create_database_account(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + # Verify initial state: Both locations are in write_regional_routing_contexts + write_contexts = lc.get_write_regional_routing_contexts() + assert len(write_contexts) == 2 + assert write_contexts[0].get_primary() == location1_endpoint + assert write_contexts[1].get_primary() == location2_endpoint + + # Mark location1 (East US) as unavailable for writes + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + # After marking unavailable, the routing list should still contain + # both endpoints - healthy ones first, unavailable ones at the end + write_contexts_after = lc.get_write_regional_routing_contexts() + assert len(write_contexts_after) == 2, \ + f"Expected 2 endpoints in routing list, got {len(write_contexts_after)}. " \ + "Unavailable endpoint was incorrectly dropped!" + # location2 (healthy) should be first + assert write_contexts_after[0].get_primary() == location2_endpoint + # location1 (unavailable) should be at the end as fallback + assert write_contexts_after[1].get_primary() == location1_endpoint + + # Now simulate the customer request with excluded_locations = ["location2"] + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = [location2_name] + + # Resolve endpoint - should get location1 (unavailable) as the only remaining option + # NOT the global default endpoint! + resolved_endpoint = lc.resolve_service_endpoint(write_request) + + # Should fall back to location1 (unavailable regional endpoint) + # NOT the global endpoint + assert resolved_endpoint == location1_endpoint, \ + f"Expected {location1_endpoint} but got {resolved_endpoint}. " \ + f"Bug: Unavailable endpoint was dropped and SDK fell back to global endpoint!" + + def test_unavailable_endpoints_ordering_in_routing_list(self): + """ + Test that healthy endpoints come before unavailable endpoints in the routing list. + This ensures the SDK tries healthy regions first, but has unavailable ones as fallback. + """ + # Setup: Three preferred locations + preferred_locations = [location1_name, location2_name, location3_name] + lc = refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + db_acc = create_database_account(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + # Mark location1 as unavailable + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + # Check ordering: location2, location3 (healthy) should come before location1 (unavailable) + write_contexts = lc.get_write_regional_routing_contexts() + assert len(write_contexts) == 3 + assert write_contexts[0].get_primary() == location2_endpoint # First healthy + assert write_contexts[1].get_primary() == location3_endpoint # Second healthy + assert write_contexts[2].get_primary() == location1_endpoint # Unavailable at end + + # Mark location2 as unavailable too + lc.mark_endpoint_unavailable_for_write(location2_endpoint, refresh_cache=True, context="test") + + # Check ordering: location3 (healthy) should come before location1, location2 (unavailable) + write_contexts = lc.get_write_regional_routing_contexts() + assert len(write_contexts) == 3 + assert write_contexts[0].get_primary() == location3_endpoint # Only healthy + # Unavailable ones at end, in original preferred order + assert write_contexts[1].get_primary() == location1_endpoint + assert write_contexts[2].get_primary() == location2_endpoint + if __name__ == "__main__": unittest.main() From cc412f07fa4e5d4b5ed0e25edec285d09269ca10 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Sun, 3 May 2026 15:27:12 -0500 Subject: [PATCH 02/18] Backport feed-range continuation and split-resume fixes for 4.14.7 --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 6 + sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 96 +- .../azure-cosmos/azure/cosmos/_constants.py | 1 + .../azure/cosmos/_cosmos_client_connection.py | 422 ++++- .../aio/base_execution_context.py | 14 +- .../base_execution_context.py | 17 +- .../azure/cosmos/_query_aggregate_utils.py | 268 +++ .../_routing/feed_range_continuation.py | 892 ++++++++++ .../aio/_cosmos_client_connection_async.py | 409 ++++- .../azure-cosmos/azure/cosmos/container.py | 1 + .../tests/test_crud_subpartition.py | 109 ++ .../tests/test_crud_subpartition_async.py | 109 ++ .../test_feed_range_continuation_token.py | 1523 +++++++++++++++++ .../tests/test_partition_split_retry_unit.py | 165 +- .../test_partition_split_retry_unit_async.py | 164 +- sdk/cosmos/azure-cosmos/tests/test_query.py | 395 +++++ .../azure-cosmos/tests/test_query_async.py | 396 +++++ .../test_query_feed_range_multipartition.py | 1005 +++++++++++ ...t_query_feed_range_multipartition_async.py | 811 +++++++++ 19 files changed, 6623 insertions(+), 180 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 8dd5d718b51a..c456fbeddadd 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,11 @@ ## Release History +### 4.14.7 (Unreleased) + +#### Bugs Fixed +* Fixed a bug in `query_items(feed_range=...)` where pagination could return incorrect results after a partition split caused the supplied feed range to overlap multiple physical partitions. +* Fixed bug where unavailable regional endpoints were dropped from the routing list instead of being kept as fallback options. See [PR 45200](https://github.com/Azure/azure-sdk-for-python/pull/45200) + ### 4.14.6 (2026-02-02) #### Bugs Fixed diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index c20a879eff31..fab49b416c2c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -39,6 +39,11 @@ from . import documents from . import http_constants from . import _runtime_constants +from ._query_aggregate_utils import ( + _AggregatePartialClassification, + _classify_aggregate_partial, + _get_select_value_aggregate_function, +) from ._constants import _Constants as Constants from .auth import _get_authorization_header from .offer import ThroughputProperties @@ -124,6 +129,7 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]: options['accessCondition'] = {'type': 'IfNoneMatch', 'condition': if_none_match} return options + def _merge_query_results( results: dict[str, Any], partial_result: dict[str, Any], @@ -163,22 +169,13 @@ def _merge_query_results( results_docs = results.get("Documents") - # Check if both results are aggregate queries - is_partial_agg = ( - isinstance(partial_docs, list) - and len(partial_docs) == 1 - and isinstance(partial_docs[0], dict) - and partial_docs[0].get("_aggregate") is not None - ) - is_results_agg = ( - results_docs - and isinstance(results_docs, list) - and len(results_docs) == 1 - and isinstance(results_docs[0], dict) - and results_docs[0].get("_aggregate") is not None - ) + partial_aggregate_class = _classify_aggregate_partial(partial_docs, query) + results_aggregate_class = _classify_aggregate_partial(results_docs, query) - if is_partial_agg and is_results_agg: + if ( + partial_aggregate_class == _AggregatePartialClassification.OBJECT + and results_aggregate_class == _AggregatePartialClassification.OBJECT + ): agg_results = results_docs[0]["_aggregate"] # type: ignore[index] agg_partial = partial_docs[0]["_aggregate"] for key in agg_partial: @@ -196,33 +193,26 @@ def _merge_query_results( agg_results[key] += agg_partial[key] return results - # Check if both are VALUE aggregate queries - is_partial_value_agg = ( - isinstance(partial_docs, list) - and len(partial_docs) == 1 - and isinstance(partial_docs[0], (int, float)) - ) - is_results_value_agg = ( - results_docs - and isinstance(results_docs, list) - and len(results_docs) == 1 - and isinstance(results_docs[0], (int, float)) - ) - - if is_partial_value_agg and is_results_value_agg: - query_text = query.get("query") if isinstance(query, dict) else query - if query_text: - query_upper = query_text.upper() - # For MIN/MAX, we find the min/max of the partial results. - # For COUNT/SUM, we sum the partial results. - # Without robust query parsing, we can't distinguish them reliably. - # Defaulting to sum for COUNT/SUM. MIN/MAX VALUE queries are not fully supported client-side. - if " SELECT VALUE MIN" in query_upper: - results_docs[0] = min(results_docs[0], partial_docs[0]) # type: ignore[index] - elif " SELECT VALUE MAX" in query_upper: - results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index] - else: # For COUNT/SUM, we sum the partial results - results_docs[0] += partial_docs[0] # type: ignore[index] + if ( + partial_aggregate_class == _AggregatePartialClassification.VALUE + and results_aggregate_class == _AggregatePartialClassification.VALUE + ): + aggregate_fn = _get_select_value_aggregate_function(query) + if aggregate_fn is None: + raise ValueError( + "Invariant violation: VALUE aggregate classification requires a recognized aggregate function." + ) + if aggregate_fn == "MIN": + results_docs[0] = min(results_docs[0], partial_docs[0]) # type: ignore[index] + elif aggregate_fn == "MAX": + results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index] + elif aggregate_fn == "AVG": + raise ValueError( + "VALUE AVG aggregate merge across partitions is not supported client-side." + ) + else: + # COUNT/SUM are additive. + results_docs[0] += partial_docs[0] # type: ignore[index] return results # Standard query, append documents @@ -234,6 +224,28 @@ def _merge_query_results( return results +def _raise_query_merge_value_error(merge_error: ValueError) -> None: + """Raise a clearer user-facing error for unsupported VALUE aggregate merges. + + ``SELECT VALUE AVG(...)`` partials cannot be merged correctly client-side + across multiple partition/range responses. We fail loudly instead of + falling back to list concatenation (which would silently produce + mathematically incorrect results). + + :param merge_error: ValueError raised while merging partial query results. + :type merge_error: ValueError + :raises ValueError: Always re-raises, potentially with a clearer message. + """ + merge_message = str(merge_error) + if "VALUE AVG aggregate merge across partitions is not supported client-side." in merge_message: + raise ValueError( + "Unsupported query shape for range-scoped pagination: " + "SELECT VALUE AVG(...) cannot be merged client-side when the query " + "scope spans multiple physical partitions." + ) from merge_error + raise merge_error + + def GetHeaders( # pylint: disable=too-many-statements,too-many-branches cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], default_headers: Mapping[str, Any], diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index f49798c8f8b8..2efb813a498a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -66,6 +66,7 @@ class _Constants: AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default" INFERENCE_SERVICE_DEFAULT_SCOPE = "https://dbinference.azure.com/.default" SEMANTIC_RERANKER_INFERENCE_ENDPOINT: str = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT" + EMIT_STRUCTURED_CONTINUATION_PK_CONFIG: str = "AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK" # Health Check Retry Policy constants AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index fa1f6efdced6..225f32889b1c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -23,11 +23,12 @@ """Document client class for the Azure Cosmos database service. """ +import logging import os import urllib.parse import uuid from concurrent.futures.thread import ThreadPoolExecutor -from typing import Callable, Any, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast from typing_extensions import TypedDict from urllib3.util.retry import Retry @@ -69,10 +70,28 @@ from ._read_items_helper import ReadItemsHelperSync from ._request_object import RequestObject from ._retry_utility import ConnectionRetryPolicy -from ._routing import routing_map_provider, routing_range +from ._routing import routing_map_provider +from ._routing.feed_range_continuation import ( + _FeedRangePaginationState, + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES, + _apply_feedrange_request_headers, + _build_scope_from_overlaps, + _count_page_items_from_partial_result, + _decode_token, + _derive_initial_feedranges, + _hash_feed_range, + _hash_query_spec, + _increment_explode_iterations_or_raise, + _normalize_max_item_count, + _should_attempt_legacy_bridge_fallback, + _update_no_progress_page_count, + _validate_token_identity, + _write_query_outbound_continuation, +) from ._inference_service import _InferenceService from .documents import ConnectionPolicy, DatabaseAccount from .partition_key import ( + _build_partition_key_from_properties, _Undefined, _Empty, _PartitionKeyKind, @@ -81,6 +100,8 @@ _return_undefined_or_empty_partition_key, ) +_LOGGER = logging.getLogger(__name__) + class CredentialDict(TypedDict, total=False): masterKey: str resourceTokens: Mapping[str, Any] @@ -138,7 +159,12 @@ def __init__( # pylint: disable=too-many-statements """ self.client_id = str(uuid.uuid4()) self.url_connection = url_connection + self._emit_structured_continuation_pk = os.environ.get( + Constants.EMIT_STRUCTURED_CONTINUATION_PK_CONFIG, + "", + ).strip().lower() in ("1", "true", "yes", "on") self.master_key: Optional[str] = None + self.resource_tokens: Optional[Mapping[str, Any]] = None self.aad_credentials: Optional[TokenCredential] = None if auth is not None: @@ -3182,6 +3208,18 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma if timeout is not None: kwargs.setdefault("timeout", timeout) + internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( + "_internal_response_headers_capture", None + ) + + def _capture_internal_headers(headers: Mapping[str, Any]) -> None: + # Local helper so flow analysis can narrow Optional[Dict] once + # and every call site stays a single line. + if internal_headers_capture is None: + return + internal_headers_capture.clear() + internal_headers_capture.update(headers) + if query: __GetBodiesFromQueryResult = result_fn else: @@ -3227,12 +3265,13 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: change_feed_state.populate_request_headers(self._routing_map_provider, headers, feed_options) request_params.headers = headers - result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) - self.last_response_headers = last_response_headers + result, get_response_headers = self.__Get(path, request_params, headers, **kwargs) + self.last_response_headers = get_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(get_response_headers) if response_hook: - response_hook(last_response_headers, result) - return __GetBodiesFromQueryResult(result), last_response_headers - + response_hook(get_response_headers, result) + return __GetBodiesFromQueryResult(result), get_response_headers query = self.__CheckAndUnifyQueryFormat(query) if (self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, @@ -3264,8 +3303,11 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: req_headers[http_constants.HttpHeaders.IsQuery] = "true" base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id) - # Check if the over lapping ranges can be populated + # Check if the overlapping ranges can be populated feed_range_epk = None + container_properties = kwargs.pop("container_properties", None) + is_full_pk_structured_scope = False + legacy_partition_key_header = req_headers.get(http_constants.HttpHeaders.PartitionKey) if "feed_range" in kwargs: feed_range = kwargs.pop("feed_range") feed_range_epk = FeedRangeInternalEpk.from_json(feed_range).get_normalized_range() @@ -3274,73 +3316,319 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: prefix_partition_key_value: _SequentialPartitionKeyType = kwargs.pop("prefix_partition_key_value") feed_range_epk = ( prefix_partition_key_obj._get_epk_range_for_prefix_partition_key(prefix_partition_key_value)) + elif options.get("partitionKey") is not None and container_properties is not None: + partition_key_value = options["partitionKey"] + partition_key_obj = _build_partition_key_from_properties(container_properties) + if not partition_key_obj._is_prefix_partition_key(partition_key_value): + # Once we route full-PK queries through feed-range pagination, + # avoid sending the legacy partition-key header on the same request. + req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) + # Full-PK returns a single-value inclusive range; normalize to + # [min, max) before routing-map overlap resolution. + feed_range_epk = partition_key_obj._get_epk_range_for_partition_key( + partition_key_value + ).to_normalized_range() + is_full_pk_structured_scope = True # If feed_range_epk exist, query with the range if feed_range_epk is not None: - over_lapping_ranges = self._routing_map_provider.get_overlapping_ranges(resource_id, [feed_range_epk], - options) - # It is possible to get more than one over lapping range. We need to get the query results for each one - results: dict[str, Any] = {} - # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over - # lapping physical partition. The EPK sub range will be one of four: - # 1) Will have a range min equal to the feed range EPK min, and a range max equal to the over lapping - # partition - # 2) Will have a range min equal to the over lapping partition range min, and a range max equal to the - # feed range EPK range max. - # 3) will match exactly with the current over lapping physical partition, so we just return the over lapping - # physical partition's partition key id. - # 4) Will equal the feed range EPK since it is a sub range of a single physical partition - for over_lapping_range in over_lapping_ranges: - single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range) - # Since the range min and max are all Upper Cased string Hex Values, - # we can compare the values lexicographically - EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min), - range_max=min(single_range.max, feed_range_epk.max), - isMinInclusive=True, isMaxInclusive=False) - if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max: - # The Epk Sub Range spans exactly one physical partition - # In this case we can route to the physical pk range id - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] + if resource_id is None: + raise ValueError("resource_id is required for feed_range continuation.") + # The None-check above already narrows ``resource_id`` to ``str`` + # for the rest of this block. Bind it to a clearly-named local so + # the feed_range helpers below read as ``resource_id_str`` instead + # of the generic ``resource_id``. + resource_id_str: str = resource_id + # Decode and validate inbound continuation for this request. + # ``None`` means start from the beginning of the requested + # feed range. + page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) + query_hash = _hash_query_spec(query) + feedrange_hash = _hash_feed_range(feed_range_epk) + should_emit_structured_full_pk = self._emit_structured_continuation_pk + inbound_serialized_continuation = options.get("continuation") + inbound_token_payload = _decode_token(inbound_serialized_continuation) + legacy_bridge_in_use = False + legacy_fallback_attempted = False + if inbound_serialized_continuation and inbound_token_payload is None: + if is_full_pk_structured_scope: + _LOGGER.warning( + "Full-PK query continuation token is in legacy format; " + "bridging it into structured pagination state for resume." + ) + legacy_bridge_in_use = True else: - # The Epk Sub Range spans less than a single physical partition - # In this case we route to the physical partition and - # pass the epk sub range to the headers to filter within partition - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] - req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min - req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max - req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" - partial_result, last_response_headers = self.__Post( - path, request_params, query, req_headers, **kwargs + _LOGGER.warning( + "Feed-range query continuation token is not in the supported structured format; " + "restarting this feed_range query from the beginning." + ) + if inbound_token_payload is not None: + _validate_token_identity( + inbound_token_payload, + resource_id_str, + query, + feed_range_epk, + expected_query_hash=query_hash, + expected_feedrange_hash=feedrange_hash, ) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) - # Introducing a temporary complex function into a critical path to handle aggregated queries - # during splits, as a precaution falling back to the original logic if anything goes wrong - try: - results = base._merge_query_results(results, partial_result, query) - except Exception: # pylint: disable=broad-exception-caught - # If the new merge logic fails, fall back to the original logic. - if results: - results["Documents"].extend(partial_result["Documents"]) - else: - results = partial_result - if response_hook: - response_hook(last_response_headers, partial_result) - # if the prefix partition query has results lets return it - if results: - return __GetBodiesFromQueryResult(results), last_response_headers - - result, last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(req_headers, result, last_response_headers) - if last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: + pagination_state = _FeedRangePaginationState.from_inbound( + inbound_token_payload, page_size_hint + ) + elif legacy_bridge_in_use and inbound_serialized_continuation: + pagination_state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + feed_range_epk, + inbound_serialized_continuation, + page_size_hint, + ) + else: + # First call. Ask the routing map which + # partitions the input feed_range overlaps right now and turn + # each overlap into a feedrange (intersection of that partition + # and the input feed_range). + first_overlaps = self._routing_map_provider.get_overlapping_ranges( + resource_id, [feed_range_epk], dict(options) + ) + all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps) + if not all_feedranges: + # The input feed_range overlaps no current physical + # partition. Fall through to the regular __Post path + # below so cross-partition gating can still surface + # (e.g. an empty ``partition_key=[]`` prefix with + # ``enableCrossPartitionQuery=False`` must raise + # ``BAD_REQUEST``). + pagination_state = None + else: + pagination_state = _FeedRangePaginationState.from_derived_feedranges( + all_feedranges, + page_size_hint, + ) + + if pagination_state is not None: + results: dict[str, Any] = {} + feedrange_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() + consecutive_no_progress_pages = 0 + + def _checkpoint_and_reraise(error: Exception) -> None: + # Intentionally broad: stamp the latest resumable checkpoint + # for any mid-page failure, then re-raise the original error. + self.last_response_headers = feedrange_response_headers + try: + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + should_emit_structured_full_pk, + query_hash, + feedrange_hash, + ) + except Exception as continuation_write_error: # pylint: disable=broad-exception-caught + _LOGGER.warning( + "Failed to write continuation while handling query POST failure: %s", + continuation_write_error, + ) + raise error + + # NOTE: Keep this feed_range pagination loop in sync with + # ``azure/cosmos/aio/_cosmos_client_connection_async.py::__QueryFeed``. + while pagination_state.can_issue_request(): + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + + # Look up the live routing map for the current feedrange. + # Doing this every iteration is what makes the token + # split-safe. + overlapping = self._routing_map_provider.get_overlapping_ranges( + resource_id, [head_feedrange], dict(options) + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + # If routing returns multiple overlaps, the head sub-range now spans a split + # that occurred after the token was created. Re-slice and re-resolve until + # each head maps to one partition. See + # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. + explode_iterations = 0 + while pagination_state.explode_on_multi_overlap(overlapping): + explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + overlapping = self._routing_map_provider.get_overlapping_ranges( + resource_id, [head_feedrange], dict(options) + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + head_feedrange = pagination_state.head_range + if head_feedrange is None: + continue + + # Populate request headers for this single backend POST. + # The shared helper handles partition routing (PKR id + + # optional EPK filter), page-size cap, and continuation + # set/clear so the same rules apply to sync and async. + _apply_feedrange_request_headers( + req_headers, + overlapping, + partition_scope, + head_feedrange, + pagination_state.page_size_hint, + pagination_state.head_bc, + ) + # Use the session token for this specific partition so we don't + # send a compound token covering all partitions. + base.set_session_token_header( + self, req_headers, path, request_params, options, overlapping[0]["id"] + ) + + try: + backend_query_result, backend_response_headers = self.__Post( + path, request_params, query, req_headers, **kwargs + ) + except exceptions.CosmosHttpResponseError as post_error: + if ( + legacy_bridge_in_use + and not legacy_fallback_attempted + and _should_attempt_legacy_bridge_fallback(post_error) + ): + legacy_fallback_attempted = True + req_headers.pop(http_constants.HttpHeaders.PartitionKeyRangeID, None) + req_headers.pop(http_constants.HttpHeaders.StartEpkString, None) + req_headers.pop(http_constants.HttpHeaders.EndEpkString, None) + req_headers.pop(http_constants.HttpHeaders.ReadFeedKeyType, None) + if legacy_partition_key_header is not None: + req_headers[http_constants.HttpHeaders.PartitionKey] = legacy_partition_key_header + req_headers[http_constants.HttpHeaders.Continuation] = inbound_serialized_continuation + base.set_session_token_header( + self, req_headers, path, request_params, options, partition_key_range_id + ) + try: + backend_query_result, backend_response_headers = self.__Post( + path, request_params, query, req_headers, **kwargs + ) + except Exception as fallback_error: # pylint: disable=broad-exception-caught + _checkpoint_and_reraise(fallback_error) + self.last_response_headers = backend_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired( + req_headers, backend_query_result, backend_response_headers + ) + if response_hook: + response_hook(backend_response_headers, backend_query_result) + return __GetBodiesFromQueryResult(backend_query_result), backend_response_headers + _checkpoint_and_reraise(post_error) + except Exception as post_error: # pylint: disable=broad-exception-caught + _checkpoint_and_reraise(post_error) + feedrange_response_headers = backend_response_headers + self.last_response_headers = feedrange_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) + + # Merge results, falling back to a plain extend if the + # aggregating merge raises (it can on aggregated queries + # during splits). + try: + results = base._merge_query_results(results, backend_query_result, query) + except ValueError as merge_error: + base._raise_query_merge_value_error(merge_error) + except (TypeError, KeyError) as merge_error: + _LOGGER.warning( + "Falling back to non-aggregate merge after aggregate merge failure: %s", + merge_error, + ) + results_docs = results.get("Documents") if results else None + partial_docs = backend_query_result.get("Documents") if backend_query_result else None + if isinstance(results_docs, list) and isinstance(partial_docs, list): + results_docs.extend(partial_docs) + elif backend_query_result: + results = backend_query_result + + previous_feedrange = pagination_state.head_range + previous_backend_continuation = pagination_state.head_bc + page_items_returned = _count_page_items_from_partial_result(backend_query_result, query) + if response_hook: + response_hook(backend_response_headers, backend_query_result) + pagination_state.apply_post_result( + page_items_returned, + backend_response_headers.get(http_constants.HttpHeaders.Continuation), + ) + consecutive_no_progress_pages = _update_no_progress_page_count( + consecutive_no_progress_pages, + page_items_returned, + previous_feedrange, + previous_backend_continuation, + pagination_state.head_range, + pagination_state.head_bc, + ) + if ( + consecutive_no_progress_pages >= _MAX_CONSECUTIVE_NO_PROGRESS_PAGES + and consecutive_no_progress_pages % _MAX_CONSECUTIVE_NO_PROGRESS_PAGES == 0 + ): + # Warning-only: do not fail fast here. + current_head = pagination_state.head_range + head_min = current_head.min if current_head else "" + head_max = current_head.max if current_head else "" + _LOGGER.warning( + "Feed-range query has returned 0 items for %s consecutive continuation pages " + "with the same continuation token and partition key range [%s, %s); continuing scan.", + consecutive_no_progress_pages, + head_min, + head_max, + ) + + # maxItemCount is a per-request hint. Return this SDK page + # after the first non-empty logical result instead of filling + # an exact target count by issuing extra backend requests. + if page_items_returned > 0: + break + + # Pagination loop is done — write the final outbound + # continuation (or clear the header if the queue is fully + # drained) so the caller's ``by_page`` loop terminates. + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + should_emit_structured_full_pk, + query_hash, + feedrange_hash, + ) + # End feed_range pagination block. + self.last_response_headers = feedrange_response_headers + + # if the prefix partition query has results lets return it + if results: + if feedrange_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: + index_metrics_raw = feedrange_response_headers[http_constants.HttpHeaders.IndexUtilization] + feedrange_response_headers[http_constants.HttpHeaders.IndexUtilization] = ( + _utils.get_index_metrics_info(index_metrics_raw)) + return __GetBodiesFromQueryResult(results), feedrange_response_headers + return [], feedrange_response_headers + + result, post_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = post_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(post_response_headers) + self._UpdateSessionIfRequired(req_headers, result, post_response_headers) + if post_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization - index_metrics_raw = last_response_headers[INDEX_METRICS_HEADER] - last_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) + index_metrics_raw = post_response_headers[INDEX_METRICS_HEADER] + post_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) if response_hook: - response_hook(last_response_headers, result) + response_hook(post_response_headers, result) - return __GetBodiesFromQueryResult(result), last_response_headers + return __GetBodiesFromQueryResult(result), post_response_headers def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, excluded_locations: Optional[Sequence[str]] = None, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py index 1a866c2df873..fd356c94d77c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py @@ -51,6 +51,10 @@ def __init__(self, client, options): self._has_started = False self._has_finished = False self._buffer = deque() + self._resource_link = None + # Per-query mutable capture used by __QueryFeed to report response + # headers (including failure checkpoints) without crossing requests. + self._internal_response_headers_capture = {} def _get_initial_continuation(self): if "continuation" in self._options: @@ -118,6 +122,9 @@ async def _fetch_items_helper_no_retries(self, fetch_function): """ fetched_items = [] new_options = copy.deepcopy(self._options) + # Clear stale values from prior pages before issuing a new fetch. + self._internal_response_headers_capture.clear() + new_options["_internal_response_headers_capture"] = self._internal_response_headers_capture while self._continuation or not self._has_started: new_options["continuation"] = self._continuation @@ -185,8 +192,13 @@ async def callback(**kwargs): # pylint: disable=unused-argument # Refresh routing map to get new partition key ranges self._client.refresh_routing_map_provider() # Reset execution context state to allow retry from the beginning + + # Reset execution context state for retry. If __QueryFeed already + # stamped a checkpoint continuation on failure, resume from it. + continuation_key = http_constants.HttpHeaders.Continuation + checkpoint_continuation = self._internal_response_headers_capture.get(continuation_key) self._has_started = False - self._continuation = None + self._continuation = checkpoint_continuation # Retry immediately (no backoff needed for partition splits) continue raise # Not a partition split error, propagate immediately diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py index 528ca87f2586..a38be4b21df6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py @@ -49,6 +49,10 @@ def __init__(self, client, options): self._has_started = False self._has_finished = False self._buffer = deque() + self._resource_link = None + # Per-query mutable capture used by __QueryFeed to report response + # headers (including failure checkpoints) without crossing requests. + self._internal_response_headers_capture = {} def _get_initial_continuation(self): if "continuation" in self._options: @@ -116,6 +120,9 @@ def _fetch_items_helper_no_retries(self, fetch_function): """ fetched_items = [] new_options = copy.deepcopy(self._options) + # Clear stale values from prior pages before issuing a new fetch. + self._internal_response_headers_capture.clear() + new_options["_internal_response_headers_capture"] = self._internal_response_headers_capture while self._continuation or not self._has_started: new_options["continuation"] = self._continuation @@ -180,11 +187,15 @@ def callback(**kwargs): # pylint: disable=unused-argument max_retries ) - # Refresh routing map to get new partition key ranges + + # Refresh routing map to get new partition key ranges. self._client.refresh_routing_map_provider() - # Reset execution context state to allow retry from the beginning + # Reset execution context state for retry. If __QueryFeed already + # stamped a checkpoint continuation on failure, resume from it. + continuation_key = http_constants.HttpHeaders.Continuation + checkpoint_continuation = self._internal_response_headers_capture.get(continuation_key) self._has_started = False - self._continuation = None + self._continuation = checkpoint_continuation # Retry immediately (no backoff needed for partition splits) continue raise # Not a partition split error, propagate immediately diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py new file mode 100644 index 000000000000..d688188f6fdb --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py @@ -0,0 +1,268 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +from enum import Enum +from typing import Any, Optional, Union + + +# Used by query paging and query merge paths to decide whether a row is +# a normal row or part of an aggregate result. +class _AggregatePartialClassification(Enum): + """Classification for one-partition query partial payloads.""" + + NONE = "none" + OBJECT = "object" + VALUE = "value" + + +def _extract_query_text(query: Optional[Union[str, dict[str, Any]]]) -> Optional[str]: + """Extract SQL text from a string or query-spec dictionary. + + :param query: Query text or query spec dictionary. + :type query: Optional[Union[str, dict[str, Any]]] + :returns: Query text when present; otherwise ``None``. + :rtype: Optional[str] + """ + if isinstance(query, str): + return query + if isinstance(query, dict): + query_text = query.get("query") + if isinstance(query_text, str): + return query_text + return None + + +def _strip_sql_block_comments(query_text: str) -> str: + """Return ``query_text`` with ``/* ... */`` comment spans removed. + + The aggregate detector is a lightweight scanner, so this helper keeps the + same lightweight approach and removes only block comments before scanning. + Quoted strings are preserved so comment-like text inside literals does not + get stripped. + + :param query_text: Raw query text. + :type query_text: str + :returns: Query text with block comments removed. + :rtype: str + """ + out: list[str] = [] + index = 0 + length = len(query_text) + in_quote: Optional[str] = None + + while index < length: + ch = query_text[index] + + if in_quote is not None: + out.append(ch) + # SQL-style escaped quote inside same quote type, e.g. 'it''s'. + if ch == in_quote and index + 1 < length and query_text[index + 1] == in_quote: + out.append(query_text[index + 1]) + index += 2 + continue + if ch == in_quote: + in_quote = None + index += 1 + continue + + if ch in ("'", '"'): + in_quote = ch + out.append(ch) + index += 1 + continue + + if ch == "/" and index + 1 < length and query_text[index + 1] == "*": + index += 2 + while index + 1 < length and not (query_text[index] == "*" and query_text[index + 1] == "/"): + index += 1 + if index + 1 < length: + index += 2 + # Preserve token separation where a comment was removed. + out.append(" ") + continue + + out.append(ch) + index += 1 + + return "".join(out) + + +def _get_select_value_aggregate_function(query: Optional[Union[str, dict[str, Any]]]) -> Optional[str]: + """Identify the aggregate function for ``SELECT VALUE`` aggregate queries. + + This is a lightweight text heuristic (not a SQL parser). It extracts only + the OUTER ``SELECT VALUE`` projection and then matches aggregate function + names in that projection so nested subqueries do not drive outer + classification. + + :param query: Query text or query spec dictionary. + :type query: Optional[Union[str, dict[str, Any]]] + :returns: One of ``COUNT``, ``SUM``, ``MIN``, ``MAX``, ``AVG`` when matched; otherwise ``None``. + :rtype: Optional[str] + """ + query_text = _extract_query_text(query) + if not query_text: + return None + + without_comments = _strip_sql_block_comments(query_text) + normalized = " ".join(without_comments.upper().split()) + projection = _extract_outer_select_value_projection(normalized) + if projection is None: + return None + + projection = _unwrap_outer_parentheses(projection) + # A projection-level subquery should not classify as an outer VALUE aggregate. + if projection.startswith("SELECT VALUE "): + return None + + return _find_top_level_aggregate_function(projection) + + +def _find_top_level_aggregate_function(projection: str) -> Optional[str]: + """Return an aggregate function name only when it appears at the top level. + + This prevents nested projection expressions (for example ARRAY(SELECT VALUE + COUNT(...))) from being misclassified as outer VALUE aggregates. + + :param projection: SELECT VALUE projection text to inspect. + :type projection: str + :returns: Aggregate function name when matched at top level; otherwise ``None``. + :rtype: Optional[str] + """ + aggregate_fns = {"COUNT", "SUM", "MIN", "MAX", "AVG"} + depth = 0 + index = 0 + length = len(projection) + + while index < length: + ch = projection[index] + if ch == "(": + depth += 1 + index += 1 + continue + if ch == ")": + if depth > 0: + depth -= 1 + index += 1 + continue + + if depth == 0 and (ch.isalpha() or ch == "_"): + start = index + index += 1 + while index < length and (projection[index].isalnum() or projection[index] == "_"): + index += 1 + token = projection[start:index] + + if token in aggregate_fns: + lookahead = index + while lookahead < length and projection[lookahead].isspace(): + lookahead += 1 + if lookahead < length and projection[lookahead] == "(": + return token + continue + + index += 1 + + return None + + +def _unwrap_outer_parentheses(text: str) -> str: + """Strip redundant outer parentheses while preserving inner structure. + + :param text: Projection text to normalize. + :type text: str + :returns: Projection text with only redundant outer parentheses removed. + :rtype: str + """ + candidate = text.strip() + while candidate.startswith("(") and candidate.endswith(")"): + depth = 0 + balanced = True + outer_pair = False + for idx, char in enumerate(candidate): + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + if depth < 0: + balanced = False + break + # Closing the opening '(' at index 0 means we found the outer pair. + if depth == 0: + outer_pair = idx == len(candidate) - 1 + break + if not balanced or not outer_pair: + break + candidate = candidate[1:-1].strip() + return candidate + + +def _extract_outer_select_value_projection(normalized_query: str) -> Optional[str]: + """Return the outer ``SELECT VALUE`` projection text up to the outer ``FROM``. + + Uses a lightweight parenthesis-depth scan so nested subqueries do not + influence outer aggregate detection. + + :param normalized_query: Uppercased, whitespace-normalized query text. + :type normalized_query: str + :returns: Outer ``SELECT VALUE`` projection when found; otherwise ``None``. + :rtype: Optional[str] + """ + select_value = "SELECT VALUE" + start_idx = normalized_query.find(select_value) + if start_idx < 0: + return None + + projection_start = start_idx + len(select_value) + if projection_start < len(normalized_query) and normalized_query[projection_start] == " ": + projection_start += 1 + + depth = 0 + index = projection_start + while index <= len(normalized_query) - 4: + ch = normalized_query[index] + if ch == "(": + depth += 1 + elif ch == ")" and depth > 0: + depth -= 1 + + if depth == 0 and normalized_query[index:index + 4] == "FROM": + prev_char = normalized_query[index - 1] if index > 0 else " " + next_char = normalized_query[index + 4] if index + 4 < len(normalized_query) else " " + if not (prev_char.isalnum() or prev_char == "_") and not (next_char.isalnum() or next_char == "_"): + projection = normalized_query[projection_start:index].strip() + return projection or None + index += 1 + + return None + + +def _classify_aggregate_partial( + docs: Any, + query: Optional[Union[str, dict[str, Any]]] +) -> _AggregatePartialClassification: + """Classify whether a partial result row is part of an aggregate result. + + :param docs: Partial ``Documents`` payload from one backend response. + :type docs: Any + :param query: Query text or query spec dictionary. + :type query: Optional[Union[str, dict[str, Any]]] + :returns: Aggregate partial classification. + :rtype: _AggregatePartialClassification + """ + if not isinstance(docs, list) or len(docs) != 1: + return _AggregatePartialClassification.NONE + + row = docs[0] + if isinstance(row, dict) and row.get("_aggregate") is not None: + return _AggregatePartialClassification.OBJECT + + # bool is intentionally excluded: VALUE-aggregate merge semantics are numeric. + if isinstance(row, (int, float)) and not isinstance(row, bool): + if _get_select_value_aggregate_function(query) is not None: + return _AggregatePartialClassification.VALUE + + return _AggregatePartialClassification.NONE diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py new file mode 100644 index 000000000000..2f55f2ea5cd1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py @@ -0,0 +1,892 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""Shared helpers for the structured ``feed_range`` continuation token. + +Both sync and async ``__QueryFeed`` implementations use this module for +token wire format, request fingerprinting, and feed-range routing helpers. + +The token stores an ordered ``c`` list of ``{min, max, bc}`` entries. +Pagination reads and updates the queue head, then advances when the head +is drained. +""" + +import base64 +import binascii +import json +from collections import deque +from typing import Any, Deque, Iterable, List, MutableMapping, Optional, Tuple + +from .. import http_constants +from .._cosmos_integers import _UInt128 +from .._cosmos_murmurhash3 import murmurhash3_128 +from .._query_aggregate_utils import _AggregatePartialClassification, _classify_aggregate_partial +from . import routing_range + + +# ----- Token wire-format constants --------------------------------------- +# Field codes for the v=1 envelope. +_TOKEN_VERSION = 1 +# Token schema version so decoders can reject unknown envelope shapes. +_FIELD_VERSION = "v" +# Resource ID for the container that originally produced this token. +_FIELD_COLLECTION_RID = "cr" +# Fingerprint of query text + parameter values to prevent wrong-query resume. +_FIELD_QUERY_HASH = "qh" +# Fingerprint of the caller's input feed_range to prevent wrong-scope resume. +_FIELD_FEEDRANGE_HASH = "frh" +# Ordered list of {min, max, bc} entries for the requested feed range. +# Iteration state comes from the list order; there is no separate +# top-level "current" field. +_FIELD_CONTINUATIONS = "c" +# Backend continuation for ONE entry. Lives INSIDE each ``c[i]`` entry, +# never at the envelope level. ``null`` means "this sub-range has not +# been started, or has been fully drained". +_FIELD_BACKEND_CONTINUATION = "bc" +# Observability threshold for repeated empty pages with no continuation/feedrange movement. +# This is warning-only (not a hard stop); pagination continues until the queue drains. +_MAX_CONSECUTIVE_NO_PROGRESS_PAGES = 1000 +# Safety guard for pathological split re-resolution loops. +_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS = 50 + + +# ----- Hash helpers ------------------------------------------------------ +def _stable_hash_128(payload: bytes) -> str: + """Stable 128-bit hex digest of ``payload``. + + Uses ``MurmurHash3_128`` (the same helper ``partition_key.py`` uses + for EPK routing). The fingerprint is non-cryptographic and used + only for an equality check inside ``_decode_token``: on resume the + SDK recomputes the same hash from the live call's inputs and + raises if it does not match the value baked into the saved token. + A cryptographic hash buys nothing here because the field is never + sent to the service and is never used as proof of input. + + :param payload: Bytes to hash. + :type payload: bytes + :returns: A 32-character hexadecimal digest. + :rtype: str + """ + return murmurhash3_128(bytearray(payload), _UInt128(0, 0)).as_hex() + + +def _hash_query_spec(query: Any) -> str: + """Hash query text + (parameter name, JSON-canonical value) pairs. + + Resume requires the exact same query shape, not a semantically + equivalent one. ``query`` may be either a string or the dict form + produced by ``__CheckAndUnifyQueryFormat``. + + :param query: Query text or query spec dictionary. + :type query: str or dict + :returns: Stable hash for query text and parameters. + :rtype: str + """ + parameters: list = [] + parts: List[bytes] = [] + if isinstance(query, dict): + parts.append((query.get("query") or "").encode("utf-8")) + parameters = query.get("parameters") or [] + else: + parts.append((query or "").encode("utf-8")) + parts.append(b"\0") + for p in parameters: + parts.append((p.get("name", "") or "").encode("utf-8")) + parts.append(b"\0") + parts.append( + json.dumps(p.get("value"), sort_keys=True, separators=(",", ":")).encode("utf-8") + ) + parts.append(b"\0") + return _stable_hash_128(b"".join(parts)) + + +def _hash_feed_range(feed_range: routing_range.Range) -> str: + """Stable 128-bit fingerprint of the INPUT feed_range. + + Detects a token that was created against a different feed_range on + the same container being replayed against the wrong scope. + + The input is first converted to a standard ``[min, max)`` form via + ``Range.to_normalized_range()`` (idempotent — returns ``self`` when + already normalized). The canonical JSON intentionally carries only + ``min`` and ``max``: under the normalized form the inclusivity + flags are constants (``True``/``False``), so hashing them adds no + signal and would only mask the fact that the fingerprint identifies + the *logical EPK interval*, not the on-the-wire representation of + the bounds. Two ``Range`` objects describing the same logical + interval (e.g. ``[A, B)`` and the equivalent ``(A-1, B-1]``) hash + equal. + + :param feed_range: Input feed range. + :type feed_range: ~azure.cosmos._routing.routing_range.Range + :returns: Stable feed range fingerprint. + :rtype: str + """ + normalized = feed_range.to_normalized_range() + canonical = json.dumps( + {"min": normalized.min, "max": normalized.max}, + sort_keys=True, + separators=(",", ":"), + ) + return _stable_hash_128(canonical.encode("utf-8")) + + +# ----- Token codec ------------------------------------------------------- +def _encode_token(payload: dict) -> str: + """JSON-serialize ``payload`` then base64-encode to a single ASCII blob. + + :param payload: Token envelope to serialize. + :type payload: dict + :returns: Base64-encoded token string. + :rtype: str + """ + return base64.b64encode( + json.dumps(payload, separators=(",", ":")).encode("utf-8") + ).decode("ascii") + + +def _decode_token(serialized: Optional[str]) -> Optional[dict]: + """Decode a continuation string into our token dict, or ``None``. + + Returns ``None`` when ``serialized`` is empty or not in our shape. + + Raises ``ValueError`` only when the input parses as our shape but is + structurally invalid (for example unknown ``v`` or missing fields). + + :param serialized: Encoded continuation token from the caller. + :type serialized: Optional[str] + :returns: Decoded token payload when valid; otherwise ``None``. + :rtype: Optional[dict] + """ + if not serialized: + return None + try: + decoded_bytes = base64.b64decode(serialized, validate=True) + decoded = json.loads(decoded_bytes.decode("utf-8")) + except (ValueError, TypeError, UnicodeDecodeError, binascii.Error): + return None # not our shape -> start fresh + if not isinstance(decoded, dict) or _FIELD_VERSION not in decoded: + return None + version = decoded.get(_FIELD_VERSION) + if version != _TOKEN_VERSION: + raise ValueError( + "Unsupported feed_range continuation token version: {}. " + "This SDK supports version {}.".format(version, _TOKEN_VERSION) + ) + _validate_v1_token_structure(decoded) + return decoded + + +def _validate_v1_token_structure(decoded: dict) -> None: + """Validate required v1 token fields so downstream code can index + them without checking for ``KeyError``. + + :param decoded: Decoded token payload to validate. + :type decoded: dict + """ + if not isinstance(decoded.get(_FIELD_COLLECTION_RID), str): + raise ValueError("Malformed feed_range continuation token: 'cr' is required.") + if not isinstance(decoded.get(_FIELD_QUERY_HASH), str): + raise ValueError("Malformed feed_range continuation token: 'qh' is required.") + if not isinstance(decoded.get(_FIELD_FEEDRANGE_HASH), str): + raise ValueError("Malformed feed_range continuation token: 'frh' is required.") + # ``bc`` must be per-entry inside ``c[i]``; top-level ``bc`` is invalid. + if _FIELD_BACKEND_CONTINUATION in decoded: + raise ValueError( + "Malformed feed_range continuation token: top-level 'bc' is not " + "supported; 'bc' must live inside each 'c' entry." + ) + + entries = decoded.get(_FIELD_CONTINUATIONS) + if not isinstance(entries, list) or not entries: + # Producers clear the continuation header when drained, so + # an on-wire token must contain at least one entry. + raise ValueError( + "Malformed feed_range continuation token: '{}' is required and " + "must be a non-empty list.".format(_FIELD_CONTINUATIONS) + ) + for idx, entry in enumerate(entries): + if not isinstance(entry, dict): + raise ValueError( + "Malformed feed_range continuation token: '{}[{}]' must be an object.".format( + _FIELD_CONTINUATIONS, idx + ) + ) + _validate_range_dict(entry, "{}[{}]".format(_FIELD_CONTINUATIONS, idx)) + + +def _validate_range_dict(range_dict: dict, field_name: str) -> None: + """Each persisted feedrange is a {'min': str, 'max': str, 'bc': str|null} dict. + + :param range_dict: Serialized feed range dictionary. + :type range_dict: dict + :param field_name: Field label used in validation messages. + :type field_name: str + """ + if not isinstance(range_dict.get("min"), str) or not isinstance(range_dict.get("max"), str): + raise ValueError( + "Malformed feed_range continuation token: '{}' and '{}' are required.".format( + f"{field_name}.min", f"{field_name}.max" + ) + ) + if _FIELD_BACKEND_CONTINUATION not in range_dict: + raise ValueError( + "Malformed feed_range continuation token: '{}.bc' is required (use null when absent).".format( + field_name + ) + ) + bc_value = range_dict[_FIELD_BACKEND_CONTINUATION] + if bc_value is not None and not isinstance(bc_value, str): + raise ValueError( + "Malformed feed_range continuation token: '{}.bc' must be a string or null.".format( + field_name + ) + ) + + +# ----- Feedrange / routing helpers --------------------------------------- +def _dict_to_range(range_dict: dict) -> routing_range.Range: + """Convert a persisted ``{'min': ..., 'max': ...}`` dict back into a ``Range``. + + :param range_dict: Persisted feed range dictionary. + :type range_dict: dict + :returns: Routing range instance. + :rtype: ~azure.cosmos._routing.routing_range.Range + """ + return routing_range.Range( + range_min=range_dict["min"], + range_max=range_dict["max"], + isMinInclusive=True, + isMaxInclusive=False, + ) + + +def _validate_token_identity( + inbound: dict, + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + expected_query_hash: Optional[str] = None, + expected_feedrange_hash: Optional[str] = None, +) -> None: + """Confirm the inbound token was created for the same collection, + query, and feed_range the current call is using. If any of the + three fingerprints disagrees, raise ``ValueError`` so the caller + finds out instead of silently getting rows from a different + request. + + :param inbound: Decoded inbound token payload. + :type inbound: dict + :param resource_id: Current collection resource ID. + :type resource_id: str + :param query: Current query spec. + :type query: str or dict + :param feed_range_epk: Current feed range scope. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param expected_query_hash: Precomputed query hash to validate against inbound token. + :type expected_query_hash: Optional[str] + :param expected_feedrange_hash: Precomputed feed_range hash to validate against inbound token. + :type expected_feedrange_hash: Optional[str] + """ + expected_qh = expected_query_hash or _hash_query_spec(query) + expected_frh = expected_feedrange_hash or _hash_feed_range(feed_range_epk) + if inbound[_FIELD_COLLECTION_RID] != resource_id: + raise ValueError( + "Continuation token was created for a different collection " + "(collection rid mismatch)." + ) + if inbound[_FIELD_QUERY_HASH] != expected_qh: + raise ValueError( + "Continuation token was created with a different query " + "(query hash mismatch). Resume requires the exact same query shape." + ) + if inbound[_FIELD_FEEDRANGE_HASH] != expected_frh: + raise ValueError( + "Continuation token was created for a different feed_range " + "(feed_range hash mismatch)." + ) + + +def _extract_resume_queue( + inbound: dict, +) -> List[Tuple[routing_range.Range, Optional[str]]]: + """Decode the ``c`` list into an ordered list of ``(range, bc)`` pairs. + + The wire format stores a single ordered ``c`` list of + ``{min, max, bc}`` entries. + + :param inbound: Decoded inbound token payload. + :type inbound: dict + :returns: Ordered list of ``(range, backend_continuation)`` pairs. + :rtype: list[tuple[~azure.cosmos._routing.routing_range.Range, Optional[str]]] + """ + return [ + (_dict_to_range(entry), entry.get(_FIELD_BACKEND_CONTINUATION)) + for entry in inbound[_FIELD_CONTINUATIONS] + ] + + +def _build_scope_from_overlaps( + overlapping: List[dict], feedrange: routing_range.Range +) -> Tuple[List[dict], routing_range.Range]: + """Compute the smallest EPK ``Range`` that covers every one of the + overlapping physical partitions, and return both the original + overlaps and that combined range. + + Both the sync and async pagination paths call this directly after + awaiting / invoking ``routing_map_provider.get_overlapping_ranges`` + themselves, so the live lookup stays at the call site (sync vs. + async) and the pure combine logic is shared here. + + :param overlapping: Overlapping partition-range dictionaries. + :type overlapping: list[dict] + :param feedrange: Feed range used for error context. + :type feedrange: ~azure.cosmos._routing.routing_range.Range + :returns: Original overlaps and the combined range covering them. + :rtype: tuple[list[dict], ~azure.cosmos._routing.routing_range.Range] + """ + if not overlapping: + raise RuntimeError( + "Routing map returned no overlapping ranges for feedrange " + "[{}, {}).".format(feedrange.min, feedrange.max) + ) + min_inclusive = overlapping[0]["minInclusive"] + max_exclusive = overlapping[0]["maxExclusive"] + for overlap_range in overlapping[1:]: + if overlap_range["minInclusive"] < min_inclusive: + min_inclusive = overlap_range["minInclusive"] + if overlap_range["maxExclusive"] > max_exclusive: + max_exclusive = overlap_range["maxExclusive"] + scope = routing_range.Range( + range_min=min_inclusive, + range_max=max_exclusive, + isMinInclusive=True, + isMaxInclusive=False, + ) + return overlapping, scope + + +def _derive_initial_feedranges( + feed_range_epk: routing_range.Range, overlapping: List[dict] +) -> List[routing_range.Range]: + """Given the caller's input feed_range and the partitions it + currently overlaps, return one sub-feedrange per partition (the + intersection of the partition's range and the input feed_range), + ordered by EPK ``min``. + + :param feed_range_epk: Requested feed range. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param overlapping: Overlapping partition-range dictionaries. + :type overlapping: list[dict] + :returns: Derived feed ranges ordered by ``min``. + :rtype: list[~azure.cosmos._routing.routing_range.Range] + """ + feedranges: List[routing_range.Range] = [] + for overlap_range in overlapping: + partition_range = routing_range.Range.PartitionKeyRangeToRange(overlap_range) + feedranges.append( + routing_range.Range( + range_min=max(partition_range.min, feed_range_epk.min), + range_max=min(partition_range.max, feed_range_epk.max), + isMinInclusive=True, + isMaxInclusive=False, + ) + ) + feedranges.sort(key=lambda feedrange_range: feedrange_range.min) + return feedranges + + +class _FeedRangePaginationState: + """Tracks where a feed_range query is up to between page calls. + + Holds a single ordered queue of ``(sub-range, backend continuation)`` + pairs. The pagination loop: + + * peeks the queue head to learn the next sub-range to POST and + the backend continuation (if any) to send with it, + * updates the head's backend continuation when the backend + returns a non-null one, + * pops the head when the sub-range is drained, + * on a partition split, replaces the head with one entry per + child sub-range (each inheriting the parent's backend continuation). + + There is no separate "current vs. remaining" split. The head is + ``queue[0]`` and later entries are queued behind it. + + Split-child insertion is tail-based so existing queued ranges + remain ahead of newly discovered children. + + Not thread-safe. One instance is created per ``query_items`` call + and is mutated only by that call's pagination loop (sync or async) + — never shared across threads or concurrent tasks. + """ + + def __init__( + self, + queue: Iterable[Tuple[routing_range.Range, Optional[str]]], + page_size_hint: Optional[int], + ) -> None: + self.queue: Deque[Tuple[routing_range.Range, Optional[str]]] = deque(queue) + self.page_size_hint = page_size_hint + + @classmethod + def from_inbound( + cls, + inbound: dict, + page_size_hint: Optional[int], + ) -> "_FeedRangePaginationState": + """Build state from a decoded inbound token. + + :param inbound: Decoded inbound token payload. + :type inbound: dict + :param page_size_hint: Request page-size hint propagated to backend POSTs. + :type page_size_hint: Optional[int] + :returns: Pagination state initialized for resume. + :rtype: _FeedRangePaginationState + """ + return cls(_extract_resume_queue(inbound), page_size_hint) + + @classmethod + def from_derived_feedranges( + cls, + feedranges: Iterable[routing_range.Range], + page_size_hint: Optional[int], + ) -> "_FeedRangePaginationState": + """Build state from feedranges computed at startup (no backend + continuations yet — every entry starts with ``bc = None``). + + :param feedranges: Derived feedranges ordered by ``min``. + :type feedranges: Iterable[~azure.cosmos._routing.routing_range.Range] + :param page_size_hint: Request page-size hint propagated to backend POSTs. + :type page_size_hint: Optional[int] + :returns: Pagination state initialized for first request. + :rtype: _FeedRangePaginationState + """ + return cls(((fr, None) for fr in feedranges), page_size_hint) + + @classmethod + def from_single_feedrange_with_continuation( + cls, + feedrange: routing_range.Range, + backend_continuation: Optional[str], + page_size_hint: Optional[int], + ) -> "_FeedRangePaginationState": + """Build state for one feedrange where a backend continuation + already exists. + + Used for legacy-token compatibility on full-PK queries: + we keep the decoder strict, then bridge a legacy continuation + string into the queue head's ``bc`` slot for the single target + range. + + :param feedrange: Single feedrange to seed. + :type feedrange: ~azure.cosmos._routing.routing_range.Range + :param backend_continuation: Existing backend continuation for the range. + :type backend_continuation: Optional[str] + :param page_size_hint: Request page-size hint propagated to backend POSTs. + :type page_size_hint: Optional[int] + :returns: Pagination state initialized with one queued entry. + :rtype: _FeedRangePaginationState + """ + return cls(((feedrange, backend_continuation),), page_size_hint) + + @property + def head_range(self) -> Optional[routing_range.Range]: + """The sub-range at the head of the queue (the one the next + backend POST will target), or ``None`` when the queue is drained. + """ + return self.queue[0][0] if self.queue else None + + @property + def head_bc(self) -> Optional[str]: + """Backend continuation paired with the head sub-range, or + ``None`` if the head has not been started yet (or has nothing + more to fetch). + """ + return self.queue[0][1] if self.queue else None + + def can_issue_request(self) -> bool: + """Whether another backend POST can be issued for this page. + + :returns: ``True`` when the queue is non-empty. + :rtype: bool + """ + return bool(self.queue) + + def explode_on_multi_overlap(self, overlapping: List[dict]) -> bool: + """If the head sub-range now spans more than one physical + partition (Cosmos split it since the token was minted), + replace the head with one entry per child sub-range and carry + the parent backend continuation onto each child. + + Dequeue the parent and append child entries at the tail + (preserving child EPK order). Each child inherits the + parent ``bc`` so resume can continue after a split without + replaying the entire child feed range. + + :param overlapping: Routing overlaps for the head sub-range. + :type overlapping: list[dict] + :returns: ``True`` when the head was split into multiple children. + :rtype: bool + """ + if not self.queue or len(overlapping) <= 1: + return False + head_range, parent_bc = self.queue[0] + sub_feedranges = _derive_initial_feedranges(head_range, overlapping) + if not sub_feedranges: + return False + self.queue.popleft() + # Keep existing tail entries ahead of split children. + for sub in sub_feedranges: + self.queue.append((sub, parent_bc)) + return True + + def apply_post_result(self, items_returned: int, backend_continuation: Optional[str]) -> None: + """Apply one backend response to the queue. + + :param items_returned: Number of logical rows returned by this POST. + :type items_returned: int + :param backend_continuation: Backend continuation for the head + sub-range (``None`` when the head is drained). + :type backend_continuation: Optional[str] + """ + # Kept for call-site API symmetry and observability; page-size hints are + # no longer decremented between backend requests. + _ = items_returned + if not self.queue: + return + head_range, _ = self.queue[0] + if backend_continuation is not None: + # Update head's bc in place; head sub-range itself is unchanged. + self.queue[0] = (head_range, backend_continuation) + else: + # Head sub-range fully drained; advance to next entry. + self.queue.popleft() + + def write_outbound_continuation( + self, + last_response_headers: MutableMapping[str, Any], + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + query_hash: Optional[str] = None, + feedrange_hash: Optional[str] = None, + ) -> None: + """Set or clear the outbound continuation header from the queue. + + Empty queue means the pagination loop ran out of sub-ranges; the + header is removed and the caller's ``by_page`` loop terminates. + Otherwise the entire queue is serialized as a fresh v=1 envelope + via ``_build_outbound_token``. + + :param last_response_headers: Response headers to mutate. + :type last_response_headers: MutableMapping[str, Any] + :param resource_id: Collection resource ID. + :type resource_id: str + :param query: Query spec used for hashing. + :type query: str or dict + :param feed_range_epk: Original request feed range. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param query_hash: Optional precomputed query hash to embed in the outbound token. + :type query_hash: Optional[str] + :param feedrange_hash: Optional precomputed feed_range hash to embed in the outbound token. + :type feedrange_hash: Optional[str] + """ + if not self.queue: + last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) + return + last_response_headers[http_constants.HttpHeaders.Continuation] = _build_outbound_token( + resource_id, + query, + feed_range_epk, + self.queue, + query_hash=query_hash, + feedrange_hash=feedrange_hash, + ) + + +def _write_query_outbound_continuation( + last_response_headers: MutableMapping[str, Any], + pagination_state: _FeedRangePaginationState, + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + is_full_pk_structured_scope: bool, + should_emit_structured_full_pk: bool, + query_hash: str, + feedrange_hash: str, +) -> None: + """Write outbound continuation for feed-range pagination. + + Full-PK queries keep legacy continuation emission unless structured + emission is explicitly enabled by the client-level env-var contract. + + :param last_response_headers: Response headers to mutate. + :type last_response_headers: MutableMapping[str, Any] + :param pagination_state: Current pagination state for this request. + :type pagination_state: _FeedRangePaginationState + :param resource_id: Collection resource ID. + :type resource_id: str + :param query: Query text/spec used for hash identity. + :type query: Any + :param feed_range_epk: Original request feed range. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param is_full_pk_structured_scope: Whether request scope is full-PK on structured path. + :type is_full_pk_structured_scope: bool + :param should_emit_structured_full_pk: Whether structured emission is enabled for full-PK. + :type should_emit_structured_full_pk: bool + :param query_hash: Precomputed query hash for outbound token identity. + :type query_hash: str + :param feedrange_hash: Precomputed feed range hash for outbound token identity. + :type feedrange_hash: str + :returns: None. Mutates ``last_response_headers`` in place. + :rtype: None + """ + if is_full_pk_structured_scope and not should_emit_structured_full_pk: + legacy_outbound = pagination_state.head_bc + if legacy_outbound is None: + last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) + else: + last_response_headers[http_constants.HttpHeaders.Continuation] = legacy_outbound + return + pagination_state.write_outbound_continuation( + last_response_headers, + resource_id, + query, + feed_range_epk, + query_hash=query_hash, + feedrange_hash=feedrange_hash, + ) + + +def _should_attempt_legacy_bridge_fallback(error: Any) -> bool: + """Return whether a compatibility fallback should be attempted. + + Compatibility fallback is restricted to legacy-token bridge failures + that surface as ``400 BadRequest``. + + :param error: Exception raised by backend request execution. + :type error: Any + :returns: ``True`` when the error is a ``400 BadRequest`` compatibility failure. + :rtype: bool + """ + return getattr(error, "status_code", None) == http_constants.StatusCodes.BAD_REQUEST + + +def _build_outbound_token( + resource_id: str, + query: Any, + feed_range_epk: routing_range.Range, + entries: Iterable[Tuple[routing_range.Range, Optional[str]]], + query_hash: Optional[str] = None, + feedrange_hash: Optional[str] = None, +) -> str: + """Build and base64-encode the outbound continuation token from a + queue of ``(range, backend_continuation)`` entries. + + Persists the queue as the wire-format ``c`` list in head-first order. + + :param resource_id: Collection resource ID. + :type resource_id: str + :param query: Query spec used for hashing. + :type query: str or dict + :param feed_range_epk: Original feed range for the request. + :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range + :param entries: Ordered ``(range, bc)`` pairs to serialize. + :type entries: Iterable[tuple[~azure.cosmos._routing.routing_range.Range, Optional[str]]] + :param query_hash: Optional precomputed query hash to persist in the token envelope. + :type query_hash: Optional[str] + :param feedrange_hash: Optional precomputed feed_range hash to persist in the token envelope. + :type feedrange_hash: Optional[str] + :returns: Encoded continuation token. + :rtype: str + """ + payload = { + _FIELD_VERSION: _TOKEN_VERSION, + _FIELD_COLLECTION_RID: resource_id, + _FIELD_QUERY_HASH: query_hash or _hash_query_spec(query), + _FIELD_FEEDRANGE_HASH: feedrange_hash or _hash_feed_range(feed_range_epk), + _FIELD_CONTINUATIONS: [ + {"min": r.min, "max": r.max, _FIELD_BACKEND_CONTINUATION: bc} + for r, bc in entries + ], + } + return _encode_token(payload) + + +# ----- Pagination-loop helpers shared by sync and async ------------------ +def _normalize_max_item_count(raw_max_item_count: Any) -> Optional[int]: + """Normalize the caller's ``maxItemCount`` to a positive page-size cap or + ``None`` (unbounded). + + Three rules, applied in order: + * ``None`` (caller did not set one) -> ``None`` (unbounded; backend + decides page size). + * Non-numeric values (e.g. a malformed string) -> ``None``. Raising + here would change the error surface for callers that previously + worked by accident; ``None`` keeps them working. + * Any value ``<= 0`` -> ``None``. A zero or negative cap would make + the pagination loop emit a continuation token without issuing any + backend POST, which can produce an empty-page-with-continuation + cycle on the caller side. + + :param raw_max_item_count: Raw ``maxItemCount`` value from options. + :type raw_max_item_count: Any + """ + if raw_max_item_count is None: + return None + try: + normalized = int(raw_max_item_count) + except (TypeError, ValueError): + return None + if normalized <= 0: + return None + return normalized + + +def _count_page_items_from_partial_result( + partial_result: Optional[dict[str, Any]], + query: Any, +) -> int: + """Return how many logical items should consume the remaining page-item count. + + Aggregate partial rows are merge-input fragments, not final logical + rows, so they should not consume page items and force an early break. + + :param partial_result: One backend POST result. + :type partial_result: Optional[dict[str, Any]] + :param query: Query text or query spec dictionary. + :type query: Any + :returns: Number of items to subtract from the remaining page-item count. + :rtype: int + """ + if not partial_result: + return 0 + docs = partial_result.get("Documents") + if not isinstance(docs, list): + return 0 + if len(docs) != 1: + # Cosmos backend invariant: aggregate partial fragments are emitted as + # single-element arrays. Non-singleton arrays are treated as regular rows. + return len(docs) + + # Aggregate partials must be merged across overlaps before they count as rows. + if _classify_aggregate_partial(docs, query) != _AggregatePartialClassification.NONE: + return 0 + return 1 + + +def _update_no_progress_page_count( + current_no_progress_count: int, + page_items_returned: int, + previous_feedrange: Optional[routing_range.Range], + previous_backend_continuation: Optional[str], + head_feedrange: Optional[routing_range.Range], + head_backend_continuation: Optional[str], +) -> int: + """Track consecutive empty pages that still carry continuation. + + :param current_no_progress_count: Current consecutive no-progress page count. + :type current_no_progress_count: int + :param page_items_returned: Number of logical page items returned this iteration. + :type page_items_returned: int + :param previous_feedrange: Feedrange before processing this response. + :type previous_feedrange: Optional[~azure.cosmos._routing.routing_range.Range] + :param previous_backend_continuation: Backend continuation before response. + :type previous_backend_continuation: Optional[str] + :param head_feedrange: Feedrange after processing this response. + :type head_feedrange: Optional[~azure.cosmos._routing.routing_range.Range] + :param head_backend_continuation: Backend continuation after response. + :type head_backend_continuation: Optional[str] + :returns: Updated consecutive no-progress page count. + :rtype: int + """ + def _range_bounds(rng: Optional[routing_range.Range]) -> Optional[Tuple[str, str]]: + if rng is None: + return None + return rng.min, rng.max + + if page_items_returned > 0: + return 0 + if head_backend_continuation is None: + return 0 + if _range_bounds(head_feedrange) != _range_bounds(previous_feedrange): + return 0 + if head_backend_continuation != previous_backend_continuation: + return 0 + + # No logical rows and no cursor/feedrange movement: caller made no progress. + return current_no_progress_count + 1 + + +def _increment_explode_iterations_or_raise(current_explode_iterations: int) -> int: + """Increment split re-resolution iteration count or raise on overflow. + + :param current_explode_iterations: Current explode-loop iteration count. + :type current_explode_iterations: int + :returns: Incremented explode-loop iteration count. + :rtype: int + """ + updated_iterations = current_explode_iterations + 1 + if updated_iterations > _MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS: + raise RuntimeError( + "Exceeded {} split re-resolution iterations while expanding overlapping " + "feed ranges. This indicates a stale/corrupted routing map response." + .format(_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS) + ) + return updated_iterations + + +def _apply_feedrange_request_headers( + req_headers: MutableMapping[str, Any], + overlapping: List[dict], + partition_scope: routing_range.Range, + head_feedrange: routing_range.Range, + page_size_hint: Optional[int], + inbound_continuation: Optional[str], +) -> None: + """Populate ``req_headers`` for one backend POST against + ``head_feedrange`` and the partition currently serving it. + + Routes by ``PartitionKeyRangeID`` and only adds the EPK filter + headers when the current feed range is a strict sub-range of the + partition. Page size and continuation are explicitly set or + cleared so leftover state from the previous iteration cannot leak. + + :param req_headers: Mutable request headers to populate. + :type req_headers: MutableMapping[str, Any] + :param overlapping: Overlapping partition-range dictionaries. + :type overlapping: list[dict] + :param partition_scope: Union scope for overlapping partitions. + :type partition_scope: ~azure.cosmos._routing.routing_range.Range + :param head_feedrange: Feed range for the current backend request. + :type head_feedrange: ~azure.cosmos._routing.routing_range.Range + :param page_size_hint: Request page-size hint for the backend POST. + :type page_size_hint: Optional[int] + :param inbound_continuation: Continuation token for backend request. + :type inbound_continuation: Optional[str] + """ + pkr_id = overlapping[0]["id"] + req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = pkr_id + + is_full_partition = ( + len(overlapping) == 1 + and head_feedrange.min == partition_scope.min + and head_feedrange.max == partition_scope.max + ) + if is_full_partition: + req_headers.pop(http_constants.HttpHeaders.StartEpkString, None) + req_headers.pop(http_constants.HttpHeaders.EndEpkString, None) + else: + req_headers[http_constants.HttpHeaders.StartEpkString] = head_feedrange.min + req_headers[http_constants.HttpHeaders.EndEpkString] = head_feedrange.max + req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" + + if page_size_hint is not None: + req_headers[http_constants.HttpHeaders.PageSize] = str(page_size_hint) + else: + req_headers.pop(http_constants.HttpHeaders.PageSize, None) + + if inbound_continuation is not None: + req_headers[http_constants.HttpHeaders.Continuation] = inbound_continuation + else: + req_headers.pop(http_constants.HttpHeaders.Continuation, None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 5c8a3354061d..8f75b106573f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -23,10 +23,11 @@ """Document client class for the Azure Cosmos database service. """ +import logging import os from urllib.parse import urlparse import uuid -from typing import Callable, Any, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast from typing_extensions import TypedDict from urllib3.util.retry import Retry @@ -54,7 +55,23 @@ from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState from .._change_feed.feed_range_internal import FeedRangeInternalEpk -from .._routing import routing_range +from .._routing.feed_range_continuation import ( + _FeedRangePaginationState, + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES, + _apply_feedrange_request_headers, + _build_scope_from_overlaps, + _count_page_items_from_partial_result, + _decode_token, + _derive_initial_feedranges, + _hash_feed_range, + _hash_query_spec, + _increment_explode_iterations_or_raise, + _normalize_max_item_count, + _should_attempt_legacy_bridge_fallback, + _update_no_progress_page_count, + _validate_token_identity, + _write_query_outbound_continuation, +) from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants from .._cosmos_responses import CosmosDict, CosmosList @@ -80,6 +97,8 @@ from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from .._range_partition_resolver import RangePartitionResolver +_LOGGER = logging.getLogger(__name__) + @@ -140,6 +159,8 @@ def __init__( # pylint: disable=too-many-statements """ self.client_id = str(uuid.uuid4()) self.url_connection = url_connection + emit_structured_env = os.environ.get(Constants.EMIT_STRUCTURED_CONTINUATION_PK_CONFIG, "") + self._emit_structured_continuation_pk = emit_structured_env.strip().lower() in ("1", "true", "yes", "on") self.master_key: Optional[str] = None self.resource_tokens: Optional[Mapping[str, Any]] = None self.aad_credentials: Optional[AsyncTokenCredential] = None @@ -2082,6 +2103,7 @@ async def _Batch( options.get("partitionKey", None)) request_params.set_excluded_location_from_options(options) await base.set_session_token_header_async(self, headers, path, request_params, options) + request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[list[dict[str, Any]], CaseInsensitiveDict], result) @@ -2991,6 +3013,20 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, # we need to set operation_state in kwargs as that's where it is looked at while sending the request kwargs.setdefault("timeout", timeout) + internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( + "_internal_response_headers_capture", None + ) + + def _capture_internal_headers(headers: Mapping[str, Any]) -> None: + # `internal_headers_capture` is Optional[Dict]; checking it + # for None once inside this helper lets the type checker + # treat it as a plain Dict for the .clear()/.update() calls + # below, and keeps every call site to a single line. + if internal_headers_capture is None: + return + internal_headers_capture.clear() + internal_headers_capture.update(headers) + if query: __GetBodiesFromQueryResult = result_fn else: @@ -3034,9 +3070,11 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: feed_options) request_params.headers = headers - result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(headers, result, last_response_headers) + result, get_response_headers = await self.__Get(path, request_params, headers, **kwargs) + self.last_response_headers = get_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(get_response_headers) + self._UpdateSessionIfRequired(headers, result, get_response_headers) if response_hook: response_hook(self.last_response_headers, result) return __GetBodiesFromQueryResult(result) @@ -3066,83 +3104,324 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: await base.set_session_token_header_async(self, req_headers, path, request_params, options, partition_key_range_id) - # Check if the over lapping ranges can be populated + # Check if the overlapping ranges can be populated feed_range_epk = None + is_full_pk_structured_scope = False + legacy_partition_key_header = req_headers.get(http_constants.HttpHeaders.PartitionKey) if "feed_range" in kwargs: feed_range = kwargs.pop("feed_range") feed_range_epk = FeedRangeInternalEpk.from_json(feed_range).get_normalized_range() elif options.get("partitionKey") is not None and container_property is not None: - # check if query has prefix partition key partition_key_value = options["partitionKey"] partition_key_obj = _build_partition_key_from_properties(container_property) if partition_key_obj._is_prefix_partition_key(partition_key_value): req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) partition_key_value = cast(_SequentialPartitionKeyType, partition_key_value) feed_range_epk = partition_key_obj._get_epk_range_for_prefix_partition_key(partition_key_value) + else: + req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) + # Full-PK returns a single-value inclusive range; normalize to + # [min, max) before routing-map overlap resolution. + feed_range_epk = partition_key_obj._get_epk_range_for_partition_key( + partition_key_value + ).to_normalized_range() + is_full_pk_structured_scope = True if feed_range_epk is not None: - over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feed_range_epk], - options) - results: dict[str, Any] = {} - # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over - # lapping physical partition. The EPK sub range will be one of four: - # 1) Will have a range min equal to the feed range EPK min, and a range max equal to the over lapping - # partition - # 2) Will have a range min equal to the over lapping partition range min, and a range max equal to the - # feed range EPK range max. - # 3) will match exactly with the current over lapping physical partition, so we just return the over lapping - # physical partition's partition key id. - # 4) Will equal the feed range EPK since it is a sub range of a single physical partition - for over_lapping_range in over_lapping_ranges: - single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range) - # Since the range min and max are all Upper Cased string Hex Values, - # we can compare the values lexicographically - EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min), - range_max=min(single_range.max, feed_range_epk.max), - isMinInclusive=True, isMaxInclusive=False) - if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max: - # The Epk Sub Range spans exactly one physical partition - # In this case we can route to the physical pk range id - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] + if id_ is None: + raise ValueError("resource_id is required for feed_range continuation.") + # The None-check above already narrows ``id_`` to ``str`` for the + # rest of this block. Bind it to a clearly-named local so the + # feed_range helpers below read as ``resource_id_str`` instead + # of the generic ``id_``. + resource_id_str: str = id_ + # Decode and validate inbound continuation for this request. + # ``None`` means start from the beginning of the requested + # feed range. + page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) + query_hash = _hash_query_spec(query) + feedrange_hash = _hash_feed_range(feed_range_epk) + should_emit_structured_full_pk = self._emit_structured_continuation_pk + inbound_serialized_continuation = options.get("continuation") + inbound_token_payload = _decode_token(inbound_serialized_continuation) + legacy_bridge_in_use = False + legacy_fallback_attempted = False + if inbound_serialized_continuation and inbound_token_payload is None: + if is_full_pk_structured_scope: + _LOGGER.warning( + "Full-PK query continuation token is in legacy format; " + "bridging it into structured pagination state for resume." + ) + legacy_bridge_in_use = True else: - # The Epk Sub Range spans less than a single physical partition - # In this case we route to the physical partition and - # pass the epk sub range to the headers to filter within partition - req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"] - req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min - req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max - req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" - partial_result, last_response_headers = await self.__Post( - path, - request_params, + _LOGGER.warning( + "Feed-range query continuation token is not in the supported structured format; " + "restarting this feed_range query from the beginning." + ) + if inbound_token_payload is not None: + _validate_token_identity( + inbound_token_payload, + resource_id_str, query, - req_headers, - **kwargs + feed_range_epk, + expected_query_hash=query_hash, + expected_feedrange_hash=feedrange_hash, ) - self.last_response_headers = last_response_headers - self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) - - # Introducing a temporary complex function into a critical path to handle aggregated queries, - # during splits as a precaution falling back to the original logic if anything goes wrong - try: - results = base._merge_query_results(results, partial_result, query) - except Exception: # pylint: disable=broad-exception-caught - # If the new merge logic fails, fall back to the original logic. - if results: - results["Documents"].extend(partial_result["Documents"]) - else: - results = partial_result - - if response_hook: - response_hook(self.last_response_headers, partial_result) - # if the prefix partition query has results lets return it - if results: - return __GetBodiesFromQueryResult(results) - - result, last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) - self.last_response_headers = last_response_headers + pagination_state = _FeedRangePaginationState.from_inbound( + inbound_token_payload, page_size_hint + ) + elif legacy_bridge_in_use and inbound_serialized_continuation: + pagination_state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + feed_range_epk, + inbound_serialized_continuation, + page_size_hint, + ) + else: + first_overlaps = await self._routing_map_provider.get_overlapping_ranges( + id_, [feed_range_epk], dict(options) + ) + all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps) + if not all_feedranges: + # The input feed_range overlaps no current physical + # partition. Fall through to the regular __Post path + # below so cross-partition gating can still surface + # (e.g. an empty ``partition_key=[]`` prefix with + # ``enableCrossPartitionQuery=False`` must raise + # ``BAD_REQUEST``). + pagination_state = None + else: + pagination_state = _FeedRangePaginationState.from_derived_feedranges( + all_feedranges, + page_size_hint, + ) + + if pagination_state is not None: + results: dict[str, Any] = {} + feedrange_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() + consecutive_no_progress_pages = 0 + + def _checkpoint_and_reraise(error: Exception) -> None: + # Intentionally broad: stamp the latest resumable checkpoint + # for any mid-page failure, then re-raise the original error. + self.last_response_headers = feedrange_response_headers + try: + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + should_emit_structured_full_pk, + query_hash, + feedrange_hash, + ) + except Exception as continuation_write_error: # pylint: disable=broad-exception-caught + _LOGGER.warning( + "Failed to write continuation while handling query POST failure: %s", + continuation_write_error, + ) + raise error + + # NOTE: Keep this feed_range pagination loop in sync with + # ``azure/cosmos/_cosmos_client_connection.py::__QueryFeed``. + while pagination_state.can_issue_request(): + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + + # Look up the live routing map for the current feedrange. + # Doing this every iteration is what makes the token + # split-safe. + overlapping = await self._routing_map_provider.get_overlapping_ranges( + id_, [head_feedrange], dict(options) + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + # If routing returns multiple overlaps, the head sub-range now spans a split + # that occurred after the token was created. Re-slice and re-resolve until + # each head maps to one partition. See + # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. + explode_iterations = 0 + while pagination_state.explode_on_multi_overlap(overlapping): + explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + overlapping = await self._routing_map_provider.get_overlapping_ranges( + id_, [head_feedrange], dict(options) + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + head_feedrange = pagination_state.head_range + if head_feedrange is None: + continue + + # Populate request headers for this single backend POST. + # The shared helper handles partition routing (PKR id + + # optional EPK filter), page-size cap, and continuation + # set/clear so the same rules apply to sync and async. + _apply_feedrange_request_headers( + req_headers, + overlapping, + partition_scope, + head_feedrange, + pagination_state.page_size_hint, + pagination_state.head_bc, + ) + # Use the session token for this specific partition so we don't + # send a compound token covering all partitions. + await base.set_session_token_header_async( + self, req_headers, path, request_params, options, overlapping[0]["id"] + ) + + try: + backend_query_result, backend_response_headers = await self.__Post( + path, + request_params, + query, + req_headers, + **kwargs + ) + except exceptions.CosmosHttpResponseError as post_error: + if ( + legacy_bridge_in_use + and not legacy_fallback_attempted + and _should_attempt_legacy_bridge_fallback(post_error) + ): + legacy_fallback_attempted = True + req_headers.pop(http_constants.HttpHeaders.PartitionKeyRangeID, None) + req_headers.pop(http_constants.HttpHeaders.StartEpkString, None) + req_headers.pop(http_constants.HttpHeaders.EndEpkString, None) + req_headers.pop(http_constants.HttpHeaders.ReadFeedKeyType, None) + if legacy_partition_key_header is not None: + req_headers[http_constants.HttpHeaders.PartitionKey] = legacy_partition_key_header + req_headers[http_constants.HttpHeaders.Continuation] = inbound_serialized_continuation + await base.set_session_token_header_async( + self, req_headers, path, request_params, options, partition_key_range_id + ) + try: + backend_query_result, backend_response_headers = await self.__Post( + path, + request_params, + query, + req_headers, + **kwargs + ) + except Exception as fallback_error: # pylint: disable=broad-exception-caught + _checkpoint_and_reraise(fallback_error) + self.last_response_headers = backend_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired( + req_headers, backend_query_result, backend_response_headers + ) + if response_hook: + response_hook(backend_response_headers, backend_query_result) + return __GetBodiesFromQueryResult(backend_query_result) + _checkpoint_and_reraise(post_error) + except Exception as post_error: # pylint: disable=broad-exception-caught + _checkpoint_and_reraise(post_error) + feedrange_response_headers = backend_response_headers + self.last_response_headers = feedrange_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) + + # Merge results, falling back to a plain extend if the + # aggregating merge raises (it can on aggregated queries + # during splits). + try: + results = base._merge_query_results(results, backend_query_result, query) + except ValueError as merge_error: + base._raise_query_merge_value_error(merge_error) + except (TypeError, KeyError) as merge_error: + _LOGGER.warning( + "Falling back to non-aggregate merge after aggregate merge failure: %s", + merge_error, + ) + results_docs = results.get("Documents") if results else None + partial_docs = backend_query_result.get("Documents") if backend_query_result else None + if isinstance(results_docs, list) and isinstance(partial_docs, list): + results_docs.extend(partial_docs) + elif backend_query_result: + results = backend_query_result + + previous_feedrange = pagination_state.head_range + previous_backend_continuation = pagination_state.head_bc + page_items_returned = _count_page_items_from_partial_result(backend_query_result, query) + if response_hook: + response_hook(backend_response_headers, backend_query_result) + pagination_state.apply_post_result( + page_items_returned, + backend_response_headers.get(http_constants.HttpHeaders.Continuation), + ) + consecutive_no_progress_pages = _update_no_progress_page_count( + consecutive_no_progress_pages, + page_items_returned, + previous_feedrange, + previous_backend_continuation, + pagination_state.head_range, + pagination_state.head_bc, + ) + if ( + consecutive_no_progress_pages >= _MAX_CONSECUTIVE_NO_PROGRESS_PAGES + and consecutive_no_progress_pages % _MAX_CONSECUTIVE_NO_PROGRESS_PAGES == 0 + ): + # Warning-only: do not fail fast here. + current_head = pagination_state.head_range + head_min = current_head.min if current_head else "" + head_max = current_head.max if current_head else "" + _LOGGER.warning( + "Feed-range query has returned 0 items for %s consecutive continuation pages " + "with the same continuation token and partition key range [%s, %s); continuing scan.", + consecutive_no_progress_pages, + head_min, + head_max, + ) + + # maxItemCount is a per-request hint. Return this SDK page + # after the first non-empty logical result instead of filling + # an exact target count by issuing extra backend requests. + if page_items_returned > 0: + break + + # Pagination loop is done — write the final outbound + # continuation (or clear the header if the queue is fully + # drained) so the caller's ``by_page`` loop terminates. + _write_query_outbound_continuation( + feedrange_response_headers, + pagination_state, + resource_id_str, + query, + feed_range_epk, + is_full_pk_structured_scope, + should_emit_structured_full_pk, + query_hash, + feedrange_hash, + ) + # End feed_range pagination block. + self.last_response_headers = feedrange_response_headers + # if the prefix partition query has results lets return it + if results: + if self.last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: + index_metrics_raw = self.last_response_headers[http_constants.HttpHeaders.IndexUtilization] + self.last_response_headers[http_constants.HttpHeaders.IndexUtilization] = ( + _utils.get_index_metrics_info(index_metrics_raw)) + return __GetBodiesFromQueryResult(results) + return [] + + result, post_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = post_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(post_response_headers) + # update session for request mutates data on server side - self._UpdateSessionIfRequired(req_headers, result, last_response_headers) + self._UpdateSessionIfRequired(req_headers, result, post_response_headers) # TODO: this part might become an issue since HTTP/2 can return read-only headers if self.last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 9cb5578a8e78..b46e9a1a340f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -917,6 +917,7 @@ def query_items( # pylint:disable=docstring-missing-param # Get container property and init client container caches container_properties = self._get_properties_with_options(feed_options) + kwargs["container_properties"] = container_properties # Update 'feed_options' from 'kwargs' if utils.valid_key_value_exist(kwargs, "enable_cross_partition_query"): diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py index a7bd88a35253..c2dc81fbf01a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition.py @@ -553,6 +553,115 @@ def test_partitioned_collection_prefix_partition_query_subpartition(self): self.assertTrue("Cross partition query is required but disabled" in error.message) + def test_partitioned_collection_full_partition_key_pagination_resume_subpartition(self): + created_db = self.databaseForTest + collection_id = 'test_partitioned_collection_full_partition_key_resume_MH ' + str(uuid.uuid4()) + created_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + partition_key = ['CA', 'Oxnard', '93033'] + total_items = 35 + for i in range(total_items): + created_collection.create_item( + body={ + 'id': 'full-pk-doc-{0:03d}'.format(i), + 'state': partition_key[0], + 'city': partition_key[1], + 'zipcode': partition_key[2] + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=partition_key, + max_item_count=10 + ) + pager = query_iterable.by_page() + first_page = list(pager.next()) + self.assertGreater(len(first_page), 0) + token = pager.continuation_token + self.assertIsNotNone(token) + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resume_pager = query_iterable.by_page(token) + resumed_remaining_ids = [] + for page in resume_pager: + resumed_remaining_ids.extend(item['id'] for item in page) + + self.assertListEqual(expected_remaining_ids, resumed_remaining_ids) + + baseline_ids = [ + item['id'] for item in created_collection.query_items(query=query, partition_key=partition_key) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + self.assertListEqual(baseline_ids, fetched_ids) + + created_db.delete_container(created_collection.id) + + def test_partitioned_collection_prefix_partition_key_pagination_resume_subpartition(self): + created_db = self.databaseForTest + collection_id = 'test_partitioned_collection_prefix_partition_key_resume_MH ' + str(uuid.uuid4()) + created_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + for i in range(30): + created_collection.create_item( + body={ + 'id': 'ca-doc-{0:03d}'.format(i), + 'state': 'CA', + 'city': 'city-{0}'.format(i % 5), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + for i in range(5): + created_collection.create_item( + body={ + 'id': 'wa-doc-{0:03d}'.format(i), + 'state': 'WA', + 'city': 'city-{0}'.format(i), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=['CA'], + max_item_count=7 + ) + pager = query_iterable.by_page() + first_page = list(pager.next()) + self.assertGreater(len(first_page), 0) + token = pager.continuation_token + self.assertIsNotNone(token) + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resume_pager = query_iterable.by_page(token) + resumed_remaining_ids = [] + for page in resume_pager: + resumed_remaining_ids.extend(item['id'] for item in page) + + self.assertListEqual(expected_remaining_ids, resumed_remaining_ids) + + baseline_ids = [ + item['id'] for item in created_collection.query_items(query=query, partition_key=['CA']) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + self.assertListEqual(baseline_ids, fetched_ids) + + created_db.delete_container(created_collection.id) + def test_partition_key_range_overlap_subpartition(self): Id = 'id' MinInclusive = 'minInclusive' diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py index bcccc4735c63..759f3193d108 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_subpartition_async.py @@ -551,6 +551,115 @@ async def test_partitioned_collection_prefix_partition_query_subpartition_async( await created_db.delete_container(created_collection.id) + async def test_partitioned_collection_full_partition_key_pagination_resume_subpartition_async(self): + created_db = self.database_for_test + collection_id = 'test_partitioned_collection_full_partition_key_resume_MH_async ' + str(uuid.uuid4()) + created_collection = await created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + partition_key = ['CA', 'Oxnard', '93033'] + total_items = 35 + for i in range(total_items): + await created_collection.create_item( + body={ + 'id': 'full-pk-doc-{0:03d}'.format(i), + 'state': partition_key[0], + 'city': partition_key[1], + 'zipcode': partition_key[2] + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=partition_key, + max_item_count=10 + ) + pager = query_iterable.by_page() + first_page = [item async for item in await pager.__anext__()] + assert len(first_page) > 0 + token = pager.continuation_token + assert token is not None + + expected_remaining_ids = [] + async for page in pager: + expected_remaining_ids.extend([item['id'] async for item in page]) + + resumed_remaining_ids = [] + resume_pager = query_iterable.by_page(token) + async for page in resume_pager: + resumed_remaining_ids.extend([item['id'] async for item in page]) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] async for item in created_collection.query_items(query=query, partition_key=partition_key) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + + await created_db.delete_container(created_collection.id) + + async def test_partitioned_collection_prefix_partition_key_pagination_resume_subpartition_async(self): + created_db = self.database_for_test + collection_id = 'test_partitioned_collection_prefix_partition_key_resume_MH_async ' + str(uuid.uuid4()) + created_collection = await created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash) + ) + + for i in range(30): + await created_collection.create_item( + body={ + 'id': 'ca-doc-{0:03d}'.format(i), + 'state': 'CA', + 'city': 'city-{0}'.format(i % 5), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + for i in range(5): + await created_collection.create_item( + body={ + 'id': 'wa-doc-{0:03d}'.format(i), + 'state': 'WA', + 'city': 'city-{0}'.format(i), + 'zipcode': 'zip-{0:03d}'.format(i) + } + ) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_collection.query_items( + query=query, + partition_key=['CA'], + max_item_count=7 + ) + pager = query_iterable.by_page() + first_page = [item async for item in await pager.__anext__()] + assert len(first_page) > 0 + token = pager.continuation_token + assert token is not None + + expected_remaining_ids = [] + async for page in pager: + expected_remaining_ids.extend([item['id'] async for item in page]) + + resumed_remaining_ids = [] + resume_pager = query_iterable.by_page(token) + async for page in resume_pager: + resumed_remaining_ids.extend([item['id'] async for item in page]) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] async for item in created_collection.query_items(query=query, partition_key=['CA']) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + + await created_db.delete_container(created_collection.id) + async def test_partition_key_range_subpartition_overlap(self): Id = 'id' MinInclusive = 'minInclusive' diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py new file mode 100644 index 000000000000..10aa7b4f2e01 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -0,0 +1,1523 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""Unit tests for ``azure.cosmos._routing.feed_range_continuation``. + + +* ``TestTokenRoundTrip`` - ``_decode_token(_encode_token(p))`` returns + a structurally-equivalent dict; the wire form is valid base64 of + valid JSON; the JSON contains the five envelope keys (``v`` / ``cr`` / + ``qh`` / ``frh`` / ``c``) and a per-entry ``bc`` inside each + ``c[i]``. The wire format has NO privileged "current" slot — iteration + position is reconstructed in memory from ``c[0]``. +* ``TestVersionMismatchRejected`` - a token whose ``v`` field is set + but is not the SDK's current version raises ``ValueError`` with a + message naming both the offending and the supported version. +* ``TestIdentityFingerprintMismatch`` - a valid v=1 token whose ``cr`` + / ``qh`` / ``frh`` fingerprints disagree with the current request + raises ``ValueError`` with a message naming the failing field. + (Validation lives in the call sites; this test exercises the same + hash-based equality the call sites use.) +* ``TestExplodeOnMultiOverlap`` - when a saved feedrange resolves to + more than one physical partition on resume (the post-split case), + the call site must slice the feedrange into one sub-feedrange per + child before POSTing. These tests pin the geometry of that slice + without touching the network. +""" + +import base64 +import json + +import pytest + +from azure.cosmos import _base +from azure.cosmos import http_constants +from azure.cosmos._query_aggregate_utils import ( + _AggregatePartialClassification, + _classify_aggregate_partial, + _extract_outer_select_value_projection, + _get_select_value_aggregate_function, + _strip_sql_block_comments, +) +from azure.cosmos._routing import routing_range +from azure.cosmos._routing.feed_range_continuation import ( + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES, + _MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS, + _FeedRangePaginationState, + _apply_feedrange_request_headers, + _build_outbound_token, + _build_scope_from_overlaps, + _count_page_items_from_partial_result, + _decode_token, + _derive_initial_feedranges, + _encode_token, + _hash_feed_range, + _hash_query_spec, + _stable_hash_128, + _normalize_max_item_count, + _increment_explode_iterations_or_raise, + _update_no_progress_page_count, + _validate_token_identity, + _FIELD_BACKEND_CONTINUATION, + _FIELD_COLLECTION_RID, + _FIELD_CONTINUATIONS, + _FIELD_FEEDRANGE_HASH, + _FIELD_QUERY_HASH, + _FIELD_VERSION, + _TOKEN_VERSION, +) + + +# Fixed inputs reused across the round-trip / mismatch tests so each +# assertion compares against a known-good baseline. +# cspell:ignore AOXB BFFFFFFFFFFFFFFF BAAAAAAAAAA +_RID = "Yxs1AOXBSp4=" +_QUERY = {"query": "SELECT * FROM c WHERE c.x = @x", + "parameters": [{"name": "@x", "value": 7}]} +_FEED_RANGE = routing_range.Range( + range_min="3FFFFFFFFFFFFFFF", + range_max="BFFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, +) +_HEAD_FEEDRANGE = routing_range.Range( + range_min="3FFFFFFFFFFFFFFF", + range_max="7FFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, +) +_REMAINING_FEEDRANGE = routing_range.Range( + range_min="7FFFFFFFFFFFFFFF", + range_max="BFFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, +) +_BACKEND_CONT = "+RID:~Yxs1AOXBSp4BAAAAAAAAAA==#RT:1#TRC:5#ISV:2#IEO:65567" + + +def _mk_range(mn: str, mx: str) -> routing_range.Range: + return routing_range.Range(range_min=mn, range_max=mx, isMinInclusive=True, isMaxInclusive=False) + + +def _make_valid_token_payload() -> dict: + """Build a structurally-complete v=1 token payload over the fixtures. + + The wire format is a single ordered ``c`` list of + ``{min, max, bc}`` entries with no privileged "current" slot — + iteration position is reconstructed in memory from ``c[0]``. Each + entry carries its own ``bc``; the sequential loop only ever sets a + non-null ``bc`` for ``c[0]`` and leaves later entries' ``bc`` null. + """ + return { + _FIELD_VERSION: _TOKEN_VERSION, + _FIELD_COLLECTION_RID: _RID, + _FIELD_QUERY_HASH: _hash_query_spec(_QUERY), + _FIELD_FEEDRANGE_HASH: _hash_feed_range(_FEED_RANGE), + _FIELD_CONTINUATIONS: [ + { + "min": _HEAD_FEEDRANGE.min, + "max": _HEAD_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: _BACKEND_CONT, + }, + { + "min": _REMAINING_FEEDRANGE.min, + "max": _REMAINING_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: None, + }, + ], + } + + +class TestStableHash128MurmurRegression: + """Pin one MurmurHash3-128 output as a regression guard against + accidental algorithm drift. The other hash tests only assert + determinism / inequality between distinct inputs; this one pins + exact bytes for one fixed input.""" + + def test_known_input_produces_known_murmur_digest(self): + assert _stable_hash_128(b"feed_range_continuation_token_regression") == ( + "ce0130ea460342256309b38dfdbc9c50" + ) + + +# ---------------------------------------------------------------------- # +# Token round-trip +# ---------------------------------------------------------------------- # +class TestTokenRoundTrip: + """``_encode_token`` -> ``_decode_token`` is structurally lossless and + the wire form is base64-encoded JSON containing all seven required + fields.""" + + def test_round_trip_preserves_all_fields(self): + payload = _make_valid_token_payload() + wire = _encode_token(payload) + decoded = _decode_token(wire) + assert decoded == payload + + def test_wire_form_is_base64_of_json(self): + payload = _make_valid_token_payload() + wire = _encode_token(payload) + # Wire form must be ASCII-safe and base64-decodable; the decoded + # bytes must be valid UTF-8 JSON; the JSON must be a dict. + raw = base64.b64decode(wire, validate=True) + as_json = json.loads(raw.decode("utf-8")) + assert isinstance(as_json, dict) + + def test_wire_form_contains_five_envelope_keys_and_per_entry_bc(self): + # The envelope itself has FIVE top-level keys (no top-level + # ``bc``, no privileged ``cf``/``rf`` split); ``bc`` lives + # inside each ``c[i]`` so a future non-sequential / parallel + # loop can record one backend continuation per sub-range + # without a wire-format bump. + payload = _make_valid_token_payload() + wire = _encode_token(payload) + decoded_json = json.loads(base64.b64decode(wire, validate=True).decode("utf-8")) + envelope_required = { + _FIELD_VERSION, + _FIELD_COLLECTION_RID, + _FIELD_QUERY_HASH, + _FIELD_FEEDRANGE_HASH, + _FIELD_CONTINUATIONS, + } + assert envelope_required == set(decoded_json.keys()) + assert _FIELD_BACKEND_CONTINUATION not in decoded_json, ( + "envelope must NOT carry a top-level 'bc'; bc is per-entry" + ) + assert "cf" not in decoded_json, ( + "envelope must NOT carry a privileged 'cf' slot" + ) + assert "rf" not in decoded_json, ( + "envelope must NOT carry a 'rf' tail; sub-ranges live in a single 'c' list" + ) + assert isinstance(decoded_json[_FIELD_CONTINUATIONS], list) + assert len(decoded_json[_FIELD_CONTINUATIONS]) >= 1 + for entry in decoded_json[_FIELD_CONTINUATIONS]: + assert _FIELD_BACKEND_CONTINUATION in entry + + def test_build_outbound_token_emits_valid_token(self): + wire = _build_outbound_token( + resource_id=_RID, + query=_QUERY, + feed_range_epk=_FEED_RANGE, + entries=[ + (_HEAD_FEEDRANGE, _BACKEND_CONT), + (_REMAINING_FEEDRANGE, None), + ], + ) + decoded = _decode_token(wire) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + assert decoded[_FIELD_COLLECTION_RID] == _RID + assert decoded[_FIELD_QUERY_HASH] == _hash_query_spec(_QUERY) + assert decoded[_FIELD_FEEDRANGE_HASH] == _hash_feed_range(_FEED_RANGE) + # Head of the ``c`` list == the in-flight slice; tail == queued. + assert decoded[_FIELD_CONTINUATIONS] == [ + { + "min": _HEAD_FEEDRANGE.min, + "max": _HEAD_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: _BACKEND_CONT, + }, + { + "min": _REMAINING_FEEDRANGE.min, + "max": _REMAINING_FEEDRANGE.max, + _FIELD_BACKEND_CONTINUATION: None, + }, + ] + assert _FIELD_BACKEND_CONTINUATION not in decoded + assert "cf" not in decoded + assert "rf" not in decoded + + def test_build_outbound_token_uses_precomputed_hashes_without_rehash(self, monkeypatch): + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation._hash_query_spec", + lambda _query: (_ for _ in ()).throw(AssertionError("query hash should not be recomputed")), + ) + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation._hash_feed_range", + lambda _feed_range: (_ for _ in ()).throw(AssertionError("feed_range hash should not be recomputed")), + ) + + wire = _build_outbound_token( + resource_id=_RID, + query=_QUERY, + feed_range_epk=_FEED_RANGE, + entries=[(_HEAD_FEEDRANGE, _BACKEND_CONT)], + query_hash="precomputed-query-hash", + feedrange_hash="precomputed-feedrange-hash", + ) + decoded = _decode_token(wire) + assert decoded is not None + assert decoded[_FIELD_QUERY_HASH] == "precomputed-query-hash" + assert decoded[_FIELD_FEEDRANGE_HASH] == "precomputed-feedrange-hash" + + def test_per_entry_backend_continuations_coexist(self): + # The shape that motivated the flat ``c`` list: a future + # non-sequential / parallel-fetch loop emits a token where + # multiple entries each carry their own non-null backend + # continuation. Today's sequential loop never produces this + # state, but the wire shape must already support it so that + # when parallel fetch lands no version bump is needed. + wire = _build_outbound_token( + resource_id=_RID, + query=_QUERY, + feed_range_epk=_FEED_RANGE, + entries=[ + (_HEAD_FEEDRANGE, "B-cont-5"), + (_REMAINING_FEEDRANGE, "A-cont-5"), + ], + ) + decoded = _decode_token(wire) + assert decoded is not None + entries = decoded[_FIELD_CONTINUATIONS] + assert entries[0][_FIELD_BACKEND_CONTINUATION] == "B-cont-5" + assert entries[1][_FIELD_BACKEND_CONTINUATION] == "A-cont-5" + + def test_none_and_empty_inputs_decode_to_none(self): + # An empty / missing continuation must NOT raise - the call site + # treats it as "first call, derive feedranges from the routing map". + assert _decode_token(None) is None + assert _decode_token("") is None + + +# ---------------------------------------------------------------------- # +# Version-mismatch rejection +# ---------------------------------------------------------------------- # +class TestVersionMismatchRejected: + """A token that decodes as our shape but with a non-current ``v`` + raises ``ValueError`` rather than being silently misinterpreted.""" + + def test_future_version_raises(self): + payload = _make_valid_token_payload() + payload[_FIELD_VERSION] = 999 + wire = _encode_token(payload) + with pytest.raises(ValueError) as excinfo: + _decode_token(wire) + msg = str(excinfo.value) + assert "999" in msg + assert str(_TOKEN_VERSION) in msg + + def test_zero_version_raises(self): + payload = _make_valid_token_payload() + payload[_FIELD_VERSION] = 0 + wire = _encode_token(payload) + with pytest.raises(ValueError): + _decode_token(wire) + + def test_missing_continuations_list_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_CONTINUATIONS] + wire = _encode_token(payload) + with pytest.raises(ValueError) as excinfo: + _decode_token(wire) + assert _FIELD_CONTINUATIONS in str(excinfo.value) + + def test_empty_continuations_list_raises(self): + # An empty ``c`` list cannot legitimately appear on the wire: + # the producer clears the outbound continuation header in the + # drained case rather than emitting a token with no entries. + payload = _make_valid_token_payload() + payload[_FIELD_CONTINUATIONS] = [] + wire = _encode_token(payload) + with pytest.raises(ValueError) as excinfo: + _decode_token(wire) + assert _FIELD_CONTINUATIONS in str(excinfo.value) + + +class TestMalformedV1TokenRejected: + """Malformed v1 tokens should raise ValueError at decode time. + + This prevents downstream call-sites from seeing KeyError when indexing + required identity and feedrange fields. + """ + + def test_missing_cr_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_COLLECTION_RID] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "cr" in str(excinfo.value) + + def test_missing_qh_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_QUERY_HASH] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "qh" in str(excinfo.value) + + def test_missing_frh_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_FEEDRANGE_HASH] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "frh" in str(excinfo.value) + + def test_missing_head_min_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_CONTINUATIONS][0]["min"] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "{}[0].min".format(_FIELD_CONTINUATIONS) in str(excinfo.value) + + def test_malformed_tail_entry_raises(self): + payload = _make_valid_token_payload() + payload[_FIELD_CONTINUATIONS] = [payload[_FIELD_CONTINUATIONS][0], "not-an-object"] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "{}[1]".format(_FIELD_CONTINUATIONS) in str(excinfo.value) + + def test_non_string_backend_continuation_raises(self): + # ``bc`` lives inside each ``c[i]``; a non-string value must raise. + payload = _make_valid_token_payload() + payload[_FIELD_CONTINUATIONS][0][_FIELD_BACKEND_CONTINUATION] = 123 + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "bc" in str(excinfo.value) + + def test_envelope_level_backend_continuation_is_rejected(self): + # An older shape carried ``bc`` at the envelope root. The + # current shape moves ``bc`` inside each ``c[i]`` entry; a + # top-level ``bc`` must be rejected so a token from any earlier + # build fails loudly instead of being silently dropped. + payload = _make_valid_token_payload() + payload[_FIELD_BACKEND_CONTINUATION] = "envelope-level-bc" + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "bc" in str(excinfo.value) + + + def test_missing_per_entry_backend_continuation_raises(self): + payload = _make_valid_token_payload() + del payload[_FIELD_CONTINUATIONS][0][_FIELD_BACKEND_CONTINUATION] + with pytest.raises(ValueError) as excinfo: + _decode_token(_encode_token(payload)) + assert "bc" in str(excinfo.value) + + +# ---------------------------------------------------------------------- # +# Identity-fingerprint mismatch rejection +# ---------------------------------------------------------------------- # +class TestIdentityFingerprintMismatch: + """A valid v=1 token replayed against a different collection / query / + feed_range produces a fingerprint mismatch the call site rejects. + + The hash helpers are deterministic and the call-site validators in + ``__QueryFeed`` compare ``inbound[_FIELD_*]`` to ``_hash_*(current)`` + and raise ``ValueError`` on mismatch.""" + + def test_collection_rid_mismatch_detected(self): + payload = _make_valid_token_payload() + decoded = _decode_token(_encode_token(payload)) + assert decoded is not None + # Same call-site shape: compare cr to current resource_id. + assert decoded[_FIELD_COLLECTION_RID] == _RID + assert decoded[_FIELD_COLLECTION_RID] != "different-collection-rid==" + + def test_query_text_change_changes_hash(self): + original = _hash_query_spec(_QUERY) + modified = _hash_query_spec({ + "query": "SELECT * FROM c WHERE c.x = @x AND c.y = 1", + "parameters": _QUERY["parameters"], + }) + assert original != modified, ( + "query-text change must produce a different hash so the " + "call site can reject the resume") + + def test_query_parameter_value_change_changes_hash(self): + original = _hash_query_spec(_QUERY) + modified = _hash_query_spec({ + "query": _QUERY["query"], + "parameters": [{"name": "@x", "value": 8}], + }) + assert original != modified + + def test_query_parameter_name_change_changes_hash(self): + original = _hash_query_spec(_QUERY) + modified = _hash_query_spec({ + "query": _QUERY["query"], + "parameters": [{"name": "@y", "value": 7}], + }) + assert original != modified + + def test_query_string_form_hashes_consistently(self): + # When the caller passes a plain string (no parameters) the hash + # must still be deterministic and stable. + h1 = _hash_query_spec("SELECT * FROM c") + h2 = _hash_query_spec("SELECT * FROM c") + h3 = _hash_query_spec("SELECT VALUE c FROM c") + assert h1 == h2 + assert h1 != h3 + + def test_feed_range_change_changes_hash(self): + original = _hash_feed_range(_FEED_RANGE) + wider = _hash_feed_range(routing_range.Range( + range_min=_FEED_RANGE.min, + range_max="FFFFFFFFFFFFFFFF", + isMinInclusive=True, isMaxInclusive=False, + )) + narrower = _hash_feed_range(routing_range.Range( + range_min=_FEED_RANGE.min, + range_max="9FFFFFFFFFFFFFFF", + isMinInclusive=True, isMaxInclusive=False, + )) + assert original != wider + assert original != narrower + assert wider != narrower + + def test_feed_range_hash_is_stable(self): + # Same feed_range -> same hash on every call (no random state). + h1 = _hash_feed_range(_FEED_RANGE) + h2 = _hash_feed_range(_FEED_RANGE) + assert h1 == h2 + + def test_feed_range_inclusivity_normalization_yields_same_hash(self): + # Hashing is based on the logical normalized EPK interval, so + # equivalent ranges with different bound inclusivity spellings + # must produce the same hash. + non_normalized = routing_range.Range( + range_min="0000000000000000", + range_max="7FFFFFFFFFFFFFFF", + isMinInclusive=False, + isMaxInclusive=True, + ) + normalized_image = non_normalized.to_normalized_range() + # Sanity: normalization actually changed something so the test + # is exercising the equivalence, not a no-op. + assert (normalized_image.isMinInclusive, normalized_image.isMaxInclusive) == (True, False) + + # The two forms must hash equal because they describe the same + # logical [min, max) interval after normalization. + assert _hash_feed_range(non_normalized) == _hash_feed_range(normalized_image) + + # And feeding the function the already-normalized form must + # yield the same digest a second time (idempotent). + assert _hash_feed_range(non_normalized) == _hash_feed_range( + routing_range.Range( + range_min=normalized_image.min, + range_max=normalized_image.max, + isMinInclusive=True, + isMaxInclusive=False, + ) + ) + + def test_call_site_replay_against_other_collection_raises(self): + """Drive the production validator (``_validate_token_identity``) + with a token built for ``_RID`` and resume against a different + collection rid. It must raise ``ValueError`` whose message names + the failing field, matching the call-site contract in + ``__QueryFeed`` (sync and async).""" + payload = _make_valid_token_payload() + inbound = _decode_token(_encode_token(payload)) + assert inbound is not None + + with pytest.raises(ValueError) as excinfo: + _validate_token_identity( + inbound, + resource_id="different-collection-rid==", + query=_QUERY, + feed_range_epk=_FEED_RANGE, + ) + assert "collection" in str(excinfo.value).lower() + + def test_call_site_replay_with_different_query_raises(self): + payload = _make_valid_token_payload() + inbound = _decode_token(_encode_token(payload)) + assert inbound is not None + + with pytest.raises(ValueError) as excinfo: + _validate_token_identity( + inbound, + resource_id=_RID, + query={"query": "SELECT c.id FROM c", "parameters": []}, + feed_range_epk=_FEED_RANGE, + ) + assert "query" in str(excinfo.value).lower() + + def test_call_site_replay_with_different_feed_range_raises(self): + payload = _make_valid_token_payload() + inbound = _decode_token(_encode_token(payload)) + assert inbound is not None + + other_feed_range = routing_range.Range( + range_min="0000000000000000", + range_max="3FFFFFFFFFFFFFFF", + isMinInclusive=True, + isMaxInclusive=False, + ) + with pytest.raises(ValueError) as excinfo: + _validate_token_identity( + inbound, + resource_id=_RID, + query=_QUERY, + feed_range_epk=other_feed_range, + ) + assert "feed_range" in str(excinfo.value).lower() + + def test_validate_token_identity_uses_precomputed_hashes_without_rehash(self, monkeypatch): + payload = _make_valid_token_payload() + inbound = _decode_token(_encode_token(payload)) + assert inbound is not None + + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation._hash_query_spec", + lambda _query: (_ for _ in ()).throw(AssertionError("query hash should not be recomputed")), + ) + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation._hash_feed_range", + lambda _feed_range: (_ for _ in ()).throw(AssertionError("feed_range hash should not be recomputed")), + ) + + _validate_token_identity( + inbound, + resource_id=_RID, + query=_QUERY, + feed_range_epk=_FEED_RANGE, + expected_query_hash=inbound[_FIELD_QUERY_HASH], + expected_feedrange_hash=inbound[_FIELD_FEEDRANGE_HASH], + ) + + +# ---------------------------------------------------------------------- # +# Explode-on-multi-overlap - post-split fan-out unit contract +# ---------------------------------------------------------------------- # +class TestExplodeOnMultiOverlap: + """Post-split fan-out contract for the resume path. + + Setup: a saved feedrange ``[A, C)`` lived inside one physical + partition on the day the token was emitted. By the time the token + is resumed, that partition has split at ``B`` into two children + ``X1 = [A, B)`` and ``X2 = [B, C)``. Re-resolving the saved + feedrange against the live routing map now returns two overlaps, + not one. + + If the call site just POSTed once against ``X1`` with + ``EndEpkString = C``, every row that physically lives on ``X2`` + would be silently dropped - the backend's EPK filter only returns + rows on the partition the request was routed to. + + The contract in ``__QueryFeed`` (sync and async): when + ``len(overlapping) > 1``, hand the saved feedrange and the new + children to ``_derive_initial_feedranges`` to get one sub-feedrange + per child (each the intersection of the child's range with the + saved feedrange). The first becomes the new ``head_feedrange``, + the rest are prepended to ``pending_feedranges``, and the + parent's backend continuation is dropped (it referenced the old + partition's id). The next loop iteration sees a single overlap and + falls through to the normal single-partition POST. + + The tests below pin four properties of that slice: + + * one sub-feedrange per child, + * sub-feedranges cover the saved feedrange end-to-end with no + gap and no overlap, + * order is by EPK ``min`` regardless of input order, + * each sub-feedrange resolves to exactly one child on the next + loop iteration (so the slice branch does not re-fire).""" + + @staticmethod + def _pkr(pkr_id: str, mn: str, mx: str) -> dict: + # The minimal partition_key_range dict shape that the routing + # map provider hands back. _build_scope_from_overlaps and + # _derive_initial_feedranges both consume this shape directly. + return {"id": pkr_id, "minInclusive": mn, "maxExclusive": mx} + + def test_two_child_split_slices_into_two_sub_feedranges(self): + # Saved feedrange covers the whole of the pre-split parent. + # After the split it resolves to two children X1 and X2; the + # slice must hand back one sub-feedrange per child. + saved_feedrange = routing_range.Range( + range_min="05C1D9D533F364", range_max="05C1D9F59FF5A0", + isMinInclusive=True, isMaxInclusive=False) + x1 = self._pkr("X1", "05C1D9D533F364", "05C1D9E4000000") + x2 = self._pkr("X2", "05C1D9E4000000", "05C1D9F59FF5A0") + + sub_feedranges = _derive_initial_feedranges(saved_feedrange, [x1, x2]) + + assert len(sub_feedranges) == 2, ( + "Day-N resolution returned 2 children but the slice " + "produced {} sub-feedranges".format(len(sub_feedranges))) + assert (sub_feedranges[0].min, sub_feedranges[0].max) == ( + "05C1D9D533F364", "05C1D9E4000000"), ( + "first sub-feedrange should be the X1 slice of the saved feedrange") + assert (sub_feedranges[1].min, sub_feedranges[1].max) == ( + "05C1D9E4000000", "05C1D9F59FF5A0"), ( + "second sub-feedrange should be the X2 slice of the saved feedrange") + + def test_sub_feedranges_partition_parent_exactly(self): + # A wider variant: the saved feedrange sits inside an even + # bigger old partition that has since split into THREE children. + # The slice must still cover the saved feedrange end-to-end - + # every row that was returnable under the old layout must still + # be reachable under the new one, no gap (= missing rows) and + # no overlap (= duplicates). + saved_feedrange = routing_range.Range( + range_min="20", range_max="E0", + isMinInclusive=True, isMaxInclusive=False) + children = [ + self._pkr("c1", "00", "55"), + self._pkr("c2", "55", "AA"), + self._pkr("c3", "AA", "FF"), + ] + sub_feedranges = _derive_initial_feedranges(saved_feedrange, children) + + bounds = [(s.min, s.max) for s in sub_feedranges] + assert bounds == [("20", "55"), ("55", "AA"), ("AA", "E0")], ( + "sub-feedranges must be the intersections of each child with " + "the saved feedrange; got {}".format(bounds)) + # First sub-feedrange starts where the saved feedrange starts; + # last one ends where it ends. Anything else loses rows at the + # edges. + assert bounds[0][0] == saved_feedrange.min + assert bounds[-1][1] == saved_feedrange.max + # And the sub-feedranges butt up against each other with no gap + # and no overlap. + for i in range(len(bounds) - 1): + assert bounds[i][1] == bounds[i + 1][0], ( + "sub-feedranges {} and {} have a gap or overlap at the " + "boundary; rows in between would be missed or " + "duplicated".format(bounds[i], bounds[i + 1])) + + def test_sub_feedranges_are_deterministically_ordered(self): + # The routing map provider doesn't promise any particular order + # when it returns the children. The call site prepends the + # leftover sub-feedranges to pending_feedranges, so if the + # order depended on what the provider happened to return, two + # different SDK processes resuming the same token could end up + # walking the children in different orders - and emit different + # outbound tokens halfway through. Pin EPK-min order regardless + # of input order. + saved_feedrange = routing_range.Range( + range_min="05C1D9D533F364", range_max="05C1D9F59FF5A0", + isMinInclusive=True, isMaxInclusive=False) + x1 = self._pkr("X1", "05C1D9D533F364", "05C1D9E4000000") + x2 = self._pkr("X2", "05C1D9E4000000", "05C1D9F59FF5A0") + + forward = _derive_initial_feedranges(saved_feedrange, [x1, x2]) + reverse = _derive_initial_feedranges(saved_feedrange, [x2, x1]) + + assert ([(r.min, r.max) for r in forward] + == [(r.min, r.max) for r in reverse]), ( + "slice result depended on input child order; resuming the " + "same token from two processes would diverge") + + def test_each_sub_feedrange_resolves_to_a_single_child(self): + # Why the slice is correct end-to-end: after slicing, the + # NEW head_feedrange is X1's slice, and the next iteration + # of the __QueryFeed loop re-resolves it against the routing + # map. That re-resolution must come back with exactly one + # overlap (X1) - otherwise we'd loop into the slice branch a + # second time. Same for X2 once X1 is drained. This pins the + # invariant that each sub-feedrange routes cleanly to one + # partition, which is what lets the rest of the loop fall + # through to the single-partition POST. + saved_feedrange = routing_range.Range( + range_min="05C1D9D533F364", range_max="05C1D9F59FF5A0", + isMinInclusive=True, isMaxInclusive=False) + children = [ + self._pkr("X1", "05C1D9D533F364", "05C1D9E4000000"), + self._pkr("X2", "05C1D9E4000000", "05C1D9F59FF5A0"), + ] + sub_feedranges = _derive_initial_feedranges(saved_feedrange, children) + + for sb in sub_feedranges: + owning_child = next(c for c in children + if c["minInclusive"] <= sb.min < c["maxExclusive"]) + overlaps, scope = _build_scope_from_overlaps([owning_child], sb) + assert len(overlaps) == 1, ( + "sub-feedrange [{}, {}) re-resolved to {} overlaps; the next " + "loop iteration would slice again".format( + sb.min, sb.max, len(overlaps))) + assert (scope.min, scope.max) == (sb.min, sb.max), ( + "sub-feedrange [{}, {}) routes to a partition whose scope is " + "[{}, {}); the EPK filter would over-fetch or " + "under-fetch".format(sb.min, sb.max, scope.min, scope.max)) + + def test_no_split_single_overlap_is_not_sliced(self): + # The "feed range fits inside one child" path: a feedrange that + # sits entirely inside a single child still resolves to one + # overlap. The slice branch is gated by `if len(overlapping) > 1`, + # so the call site goes straight to the single-partition POST. + # This is the negative control - verifying nothing in our slice + # helpers fires spuriously for the common safe case. + feed_range = routing_range.Range( + range_min="40", range_max="60", + isMinInclusive=True, isMaxInclusive=False) + overlaps, scope = _build_scope_from_overlaps( + [self._pkr("c1", "00", "80")], feed_range) + assert len(overlaps) == 1 + assert (scope.min, scope.max) == ("00", "80"), ( + "single-overlap re-resolution returned the wrong " + "partition scope") + + def test_three_child_split_slices_into_three(self): + # The 1->2 split is the common case, but nothing in the design + # caps it there - over enough wall-clock time, X1 and X2 can + # themselves split, and a saved feedrange from before all of + # those splits will resolve to 3+ overlaps. Pin that the slice + # handles N children the same way it handles 2: one + # sub-feedrange per child, in EPK order, covering the saved + # feedrange. + saved_feedrange = routing_range.Range( + range_min="00", range_max="FF", + isMinInclusive=True, isMaxInclusive=False) + children = [ + self._pkr("c1", "00", "55"), + self._pkr("c2", "55", "AA"), + self._pkr("c3", "AA", "FF"), + ] + sub_feedranges = _derive_initial_feedranges(saved_feedrange, children) + + assert len(sub_feedranges) == 3 + assert [(s.min, s.max) for s in sub_feedranges] == [ + ("00", "55"), ("55", "AA"), ("AA", "FF"), + ] + + +# ---------------------------------------------------------------------- # +# max_item_count normalization +# ---------------------------------------------------------------------- # +class TestNormalizeMaxItemCount: + """``_normalize_max_item_count`` collapses unset / non-numeric / non-positive + values to ``None`` (unbounded) and passes positive ints through unchanged. + + The pagination loop interprets ``None`` as "no client-side cap" and any + positive int as the per-page item limit. A zero or negative cap would make + the loop exit before issuing any POST while still emitting a + continuation token, leaving the caller in an empty-page-with-continuation + cycle - so those cases must be normalized to ``None``.""" + + def test_none_passes_through(self): + assert _normalize_max_item_count(None) is None + + def test_positive_int_passes_through(self): + assert _normalize_max_item_count(5) == 5 + + def test_positive_str_is_parsed(self): + assert _normalize_max_item_count("25") == 25 + + def test_zero_is_treated_as_unbounded(self): + assert _normalize_max_item_count(0) is None + + def test_negative_is_treated_as_unbounded(self): + assert _normalize_max_item_count(-1) is None + + def test_non_numeric_is_treated_as_unbounded(self): + assert _normalize_max_item_count("not-a-number") is None + + def test_object_is_treated_as_unbounded(self): + assert _normalize_max_item_count(object()) is None + + +# ---------------------------------------------------------------------- # +# Request-header shaping +# ---------------------------------------------------------------------- # +class TestApplyFeedrangeRequestHeaders: + """``_apply_feedrange_request_headers`` sets and clears routing/page/token + headers correctly for both full-partition and sub-range requests.""" + + @pytest.mark.parametrize( + "head_feedrange,expect_epk_headers", + [ + # full-partition request -> EPK headers must be cleared + (_mk_range("10", "20"), False), + # strict sub-range request -> EPK headers must be stamped + (_mk_range("12", "18"), True), + ], + ) + def test_epk_headers_match_full_vs_subrange(self, head_feedrange, expect_epk_headers): + req_headers = { + # pre-populate with stale values to prove clear behavior + http_constants.HttpHeaders.StartEpkString: "stale-start", + http_constants.HttpHeaders.EndEpkString: "stale-end", + } + overlapping = [{"id": "7", "minInclusive": "10", "maxExclusive": "20"}] + partition_scope = _mk_range("10", "20") + + _apply_feedrange_request_headers( + req_headers=req_headers, + overlapping=overlapping, + partition_scope=partition_scope, + head_feedrange=head_feedrange, + page_size_hint=None, + inbound_continuation=None, + ) + + assert req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] == "7" + assert req_headers[http_constants.HttpHeaders.ReadFeedKeyType] == "EffectivePartitionKeyRange" + if expect_epk_headers: + assert req_headers[http_constants.HttpHeaders.StartEpkString] == head_feedrange.min + assert req_headers[http_constants.HttpHeaders.EndEpkString] == head_feedrange.max + else: + assert http_constants.HttpHeaders.StartEpkString not in req_headers + assert http_constants.HttpHeaders.EndEpkString not in req_headers + + @pytest.mark.parametrize( + "page_size_hint,inbound_continuation,expect_page_size,expect_continuation", + [ + (5, "abc", True, True), + (None, "abc", False, True), + (5, None, True, False), + (None, None, False, False), + ], + ) + def test_page_size_and_continuation_are_set_or_cleared( + self, + page_size_hint, + inbound_continuation, + expect_page_size, + expect_continuation, + ): + req_headers = { + # pre-populate stale values; helper should clear when args are None + http_constants.HttpHeaders.PageSize: "999", + http_constants.HttpHeaders.Continuation: "stale-cont", + } + overlapping = [{"id": "9", "minInclusive": "30", "maxExclusive": "40"}] + partition_scope = _mk_range("30", "40") + head_feedrange = _mk_range("30", "40") + + _apply_feedrange_request_headers( + req_headers=req_headers, + overlapping=overlapping, + partition_scope=partition_scope, + head_feedrange=head_feedrange, + page_size_hint=page_size_hint, + inbound_continuation=inbound_continuation, + ) + + if expect_page_size: + assert req_headers[http_constants.HttpHeaders.PageSize] == str(page_size_hint) + else: + assert http_constants.HttpHeaders.PageSize not in req_headers + + if expect_continuation: + assert req_headers[http_constants.HttpHeaders.Continuation] == inbound_continuation + else: + assert http_constants.HttpHeaders.Continuation not in req_headers + + +class TestBudgetCounting: + """Page-item counting treats aggregate partial rows as merge fragments.""" + + def test_standard_documents_consume_page_item_limit(self): + partial_result = {"Documents": [{"id": "1"}, {"id": "2"}]} + assert _count_page_items_from_partial_result(partial_result, "SELECT * FROM c") == 2 + + def test_multi_element_documents_in_aggregate_context_consume_page_item_limit(self): + partial_result = {"Documents": [{"_aggregate": {"count": 7}}, {"_aggregate": {"count": 3}}]} + assert _count_page_items_from_partial_result(partial_result, "SELECT COUNT(1) FROM c") == 2 + + def test_object_aggregate_partial_does_not_consume_page_item_limit(self): + partial_result = {"Documents": [{"_aggregate": {"count": 7}}]} + assert _count_page_items_from_partial_result(partial_result, "SELECT COUNT(1) FROM c") == 0 + + def test_value_aggregate_partial_does_not_consume_page_item_limit(self): + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, "SELECT VALUE COUNT(1) FROM c") == 0 + + def test_value_non_aggregate_numeric_row_consumes_page_item_limit(self): + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, "SELECT VALUE c.value FROM c") == 1 + + def test_value_non_aggregate_boolean_row_consumes_page_item_limit(self): + partial_result = {"Documents": [True]} + assert _count_page_items_from_partial_result(partial_result, "SELECT VALUE c.flag FROM c") == 1 + + +class TestAggregateMergeConsistency: + """Page-item counting and merge logic should classify aggregate fragments the same way.""" + + def test_value_count_boolean_fragments_are_not_treated_as_numeric_aggregates(self): + query = "SELECT VALUE COUNT(1) > 0 FROM c" + partial_result = {"Documents": [True]} + assert _count_page_items_from_partial_result(partial_result, query) == 1 + + merged = _base._merge_query_results({"Documents": [True]}, {"Documents": [True]}, query) + assert merged["Documents"] == [True, True] + + def test_value_count_numeric_fragments_are_treated_as_aggregates(self): + query = "SELECT VALUE COUNT(1) FROM c" + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, query) == 0 + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [10] + + def test_value_min_numeric_fragments_are_merged_with_min(self): + query = "SELECT VALUE MIN(c.score) FROM c" + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, query) == 0 + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [3] + + def test_value_max_numeric_fragments_are_merged_with_max(self): + query = "SELECT VALUE MAX(c.score) FROM c" + partial_result = {"Documents": [7]} + assert _count_page_items_from_partial_result(partial_result, query) == 0 + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [7] + + def test_value_min_max_three_way_merge(self): + min_query = "SELECT VALUE MIN(c.score) FROM c" + merged_min = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, min_query) + merged_min = _base._merge_query_results(merged_min, {"Documents": [11]}, min_query) + assert merged_min["Documents"] == [3] + + max_query = "SELECT VALUE MAX(c.score) FROM c" + merged_max = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, max_query) + merged_max = _base._merge_query_results(merged_max, {"Documents": [11]}, max_query) + assert merged_max["Documents"] == [11] + + def test_value_boolean_non_aggregate_fragments_are_concatenated(self): + query = "SELECT VALUE c.flag FROM c" + partial_result = {"Documents": [True]} + assert _count_page_items_from_partial_result(partial_result, query) == 1 + + merged = _base._merge_query_results({"Documents": [True]}, {"Documents": [True]}, query) + assert merged["Documents"] == [True, True] + + def test_value_numeric_non_aggregate_fragments_are_concatenated(self): + """Regression: numeric VALUE rows must concatenate across partitions, not sum.""" + query = "SELECT VALUE c.score FROM c" + + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + assert _get_select_value_aggregate_function(query) is None + + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + assert merged["Documents"] == [7, 3] + + def test_value_float_non_aggregate_fragments_are_concatenated(self): + """Same regression for floats; they must concatenate, not collapse.""" + query = "SELECT VALUE c.ratio FROM c" + + assert _count_page_items_from_partial_result({"Documents": [1.5]}, query) == 1 + merged = _base._merge_query_results( + {"Documents": [1.5]}, {"Documents": [2.25]}, query, + ) + assert merged["Documents"] == [1.5, 2.25] + + def test_value_numeric_non_aggregate_three_way_merge_is_concatenated(self): + """Three-partition fan-in must preserve order and avoid numeric collapse.""" + query = "SELECT VALUE c.score FROM c" + merged = _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + merged = _base._merge_query_results(merged, {"Documents": [11]}, query) + assert merged["Documents"] == [7, 3, 11] + + def test_value_merge_raises_if_aggregate_function_detection_is_missing(self, monkeypatch): + query = "SELECT VALUE COUNT(1) FROM c" + monkeypatch.setattr(_base, "_get_select_value_aggregate_function", lambda _: None) + + with pytest.raises(ValueError) as excinfo: + _base._merge_query_results({"Documents": [7]}, {"Documents": [3]}, query) + + assert "VALUE aggregate classification" in str(excinfo.value) + + def test_value_avg_merge_raises_as_unsupported(self): + query = "SELECT VALUE AVG(c.value) FROM c" + + with pytest.raises(ValueError) as excinfo: + _base._merge_query_results({"Documents": [7.0]}, {"Documents": [3.0]}, query) + + assert "VALUE AVG aggregate merge" in str(excinfo.value) + + def test_raise_query_merge_value_error_rewrites_value_avg_message(self): + original = ValueError("VALUE AVG aggregate merge across partitions is not supported client-side.") + + with pytest.raises(ValueError) as excinfo: + _base._raise_query_merge_value_error(original) + + assert "SELECT VALUE AVG(...)" in str(excinfo.value) + assert "range-scoped pagination" in str(excinfo.value) + + def test_raise_query_merge_value_error_preserves_other_value_errors(self): + original = ValueError("Invariant violation: VALUE aggregate classification requires a recognized aggregate function.") + + with pytest.raises(ValueError) as excinfo: + _base._raise_query_merge_value_error(original) + + assert str(excinfo.value) == str(original) + + def test_value_aggregate_detection_allows_space_before_open_paren(self): + query = "SELECT VALUE COUNT (1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 0 + + def test_value_aggregate_detection_does_not_match_function_substrings(self): + query = "SELECT VALUE MYCOUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + def test_value_aggregate_detection_ignores_subquery_aggregate_tokens(self): + query = "SELECT VALUE c.name FROM c WHERE EXISTS(SELECT VALUE COUNT(1) FROM d)" + assert _get_select_value_aggregate_function(query) is None + + def test_numeric_value_row_with_subquery_aggregate_still_consumes_page_item(self): + query = "SELECT VALUE c.id FROM c WHERE EXISTS(SELECT VALUE COUNT(1) FROM d)" + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + def test_numeric_value_row_with_projection_subquery_aggregate_still_consumes_page_item(self): + query = "SELECT VALUE (SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + def test_numeric_value_row_with_array_projection_subquery_still_consumes_page_item(self): + query = "SELECT VALUE ARRAY(SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [7]}, query) == 1 + + +class TestSelectValueProjectionParser: + @pytest.mark.parametrize( + "normalized_query,expected_projection", + [ + ("SELECT VALUE COUNT(1) FROM C", "COUNT(1)"), + ("SELECT VALUE (SELECT VALUE COUNT(1) FROM D) FROM C", "(SELECT VALUE COUNT(1) FROM D)"), + ("SELECT VALUE C.FROMAGE FROM C", "C.FROMAGE"), + ("SELECT VALUE COUNT(1)", None), + ("SELECT VALUE (COUNT(1) FROM C", None), + ], + ) + def test_extract_outer_select_value_projection_edges(self, normalized_query, expected_projection): + assert _extract_outer_select_value_projection(normalized_query) == expected_projection + + def test_projection_level_subquery_is_not_classified_as_outer_aggregate(self): + query = "SELECT VALUE (SELECT VALUE COUNT(1) FROM d) FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_projection_level_in_subquery_is_not_classified_as_outer_aggregate(self): + query = "SELECT VALUE (SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_array_projection_subquery_is_not_classified_as_outer_aggregate(self): + query = "SELECT VALUE ARRAY(SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" + assert _get_select_value_aggregate_function(query) is None + + +class TestAggregateClassificationHeuristics: + def test_block_comment_prefix_does_not_drive_outer_select_value_detection(self): + query = "/* SELECT VALUE COUNT(1) */ SELECT VALUE c.x FROM c" + assert _get_select_value_aggregate_function(query) is None + + def test_value_aggregate_detected_with_comment_between_select_and_value(self): + query = "SELECT /* comment */ VALUE COUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + + def test_value_aggregate_detected_with_comment_between_value_and_function(self): + query = "SELECT VALUE /* comment */ COUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + + def test_comment_with_fake_from_does_not_truncate_projection(self): + query = "SELECT VALUE /* FROM d */ COUNT(1) FROM c" + assert _get_select_value_aggregate_function(query) == "COUNT" + + def test_block_comment_inside_string_literal_is_not_stripped(self): + query = "SELECT VALUE '/* COUNT(1) */' FROM c" + stripped = _strip_sql_block_comments(query) + assert "/* COUNT(1) */" in stripped + + def test_value_projection_with_property_named_count_is_not_aggregate(self): + query = "SELECT VALUE c.COUNT FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _count_page_items_from_partial_result({"Documents": [42.5]}, query) == 1 + + def test_classify_aggregate_partial_excludes_boolean_value_rows(self): + query = "SELECT VALUE COUNT(1) FROM c" + docs = [True] + assert _classify_aggregate_partial(docs, query) == _AggregatePartialClassification.NONE + + def test_classify_aggregate_partial_treats_non_aggregate_float_as_none(self): + query = "SELECT VALUE c.price FROM c" + docs = [42.5] + assert _classify_aggregate_partial(docs, query) == _AggregatePartialClassification.NONE + + +class TestEmptyPageStallCounter: + """No-progress guard only counts empty pages that still carry continuation.""" + + def test_increments_on_empty_page_with_continuation(self): + head_feedrange = _mk_range("10", "20") + assert _update_no_progress_page_count( + 3, + page_items_returned=0, + previous_feedrange=head_feedrange, + previous_backend_continuation="token", + head_feedrange=head_feedrange, + head_backend_continuation="token", + ) == 4 + + def test_increments_when_equal_bounds_are_different_objects(self): + # Guard against regressions where two equivalent ranges are reconstructed + # as distinct objects between loop iterations. + assert _update_no_progress_page_count( + 3, + page_items_returned=0, + previous_feedrange=_mk_range("10", "20"), + previous_backend_continuation="token", + head_feedrange=_mk_range("10", "20"), + head_backend_continuation="token", + ) == 4 + + def test_resets_when_items_are_returned(self): + head_feedrange = _mk_range("10", "20") + assert _update_no_progress_page_count( + 5, + page_items_returned=1, + previous_feedrange=head_feedrange, + previous_backend_continuation="token", + head_feedrange=head_feedrange, + head_backend_continuation="token", + ) == 0 + + def test_resets_when_continuation_is_none(self): + assert _update_no_progress_page_count( + _MAX_CONSECUTIVE_NO_PROGRESS_PAGES - 1, + page_items_returned=0, + previous_feedrange=_mk_range("10", "20"), + previous_backend_continuation="token", + head_feedrange=_mk_range("20", "30"), + head_backend_continuation=None, + ) == 0 + + def test_resets_when_continuation_advances(self): + head_feedrange = _mk_range("10", "20") + assert _update_no_progress_page_count( + 8, + page_items_returned=0, + previous_feedrange=head_feedrange, + previous_backend_continuation="token-1", + head_feedrange=head_feedrange, + head_backend_continuation="token-2", + ) == 0 + + +class TestExplodeIterationGuard: + def test_increments_under_limit(self): + assert _increment_explode_iterations_or_raise(0) == 1 + + def test_raises_over_limit(self): + with pytest.raises(RuntimeError) as excinfo: + _increment_explode_iterations_or_raise(_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS) + assert "split re-resolution" in str(excinfo.value) + + +class TestFeedRangePaginationState: + """Unit tests for the shared pagination state machine. + + The state machine holds a single ordered queue of + ``(sub-range, backend continuation)`` pairs — modeled on Java's + ``Queue``. There is no "current vs. + remaining" split; the head is just ``queue[0]`` and the tail is + ``queue[1:]`` by virtue of being later in the same deque. + """ + + @staticmethod + def _pkr(pkr_id: str, mn: str, mx: str) -> dict: + return {"id": pkr_id, "minInclusive": mn, "maxExclusive": mx} + + @staticmethod + def _bounds(rng: routing_range.Range) -> tuple[str, str]: + return rng.min, rng.max + + @classmethod + def _queue_bounds(cls, state) -> list: + """All ``(min, max, bc)`` triples in the queue, in head-first order.""" + return [(r.min, r.max, bc) for r, bc in state.queue] + + def test_from_derived_feedranges_empty_initializes_done_state(self): + state = _FeedRangePaginationState.from_derived_feedranges([], page_size_hint=5) + assert list(state.queue) == [] + assert state.head_range is None + assert state.head_bc is None + assert state.page_size_hint == 5 + + def test_from_derived_feedranges_seeds_queue_with_no_continuations(self): + a = _mk_range("00", "40") + b = _mk_range("40", "80") + state = _FeedRangePaginationState.from_derived_feedranges([a, b], page_size_hint=7) + assert self._queue_bounds(state) == [("00", "40", None), ("40", "80", None)] + assert self._bounds(state.head_range) == ("00", "40") + assert state.head_bc is None + assert state.page_size_hint == 7 + + def test_from_single_feedrange_with_continuation_seeds_head_bc(self): + head = _mk_range("AA", "BB") + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + head, + "legacy-token-1", + page_size_hint=11, + ) + assert self._queue_bounds(state) == [("AA", "BB", "legacy-token-1")] + assert self._bounds(state.head_range) == ("AA", "BB") + assert state.head_bc == "legacy-token-1" + assert state.page_size_hint == 11 + + def test_from_single_feedrange_with_continuation_allows_null(self): + head = _mk_range("AA", "BB") + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + head, + None, + page_size_hint=None, + ) + assert self._queue_bounds(state) == [("AA", "BB", None)] + assert state.head_bc is None + + @pytest.mark.parametrize( + "queue,page_size_hint,expected", + [ + ([], None, False), + ([(_mk_range("00", "40"), None)], 0, True), + ([(_mk_range("00", "40"), None)], -1, True), + ([(_mk_range("00", "40"), None)], None, True), + ([(_mk_range("00", "40"), None)], 1, True), + ], + ) + def test_can_issue_request_boundaries(self, queue, page_size_hint, expected): + state = _FeedRangePaginationState( + queue=queue, + page_size_hint=page_size_hint, + ) + assert state.can_issue_request() is expected + + def test_from_inbound_parses_queue_and_continuations(self): + # The wire format is a single ordered ``c`` list. The state + # machine loads it as a single queue of ``(range, bc)`` pairs + # — ``c[0]`` becomes the head, with no privileged "current vs. + # remaining" split. + inbound = { + _FIELD_CONTINUATIONS: [ + {"min": "00", "max": "40", _FIELD_BACKEND_CONTINUATION: "token-1"}, + {"min": "40", "max": "80", _FIELD_BACKEND_CONTINUATION: None}, + ], + } + state = _FeedRangePaginationState.from_inbound(inbound, page_size_hint=9) + assert self._queue_bounds(state) == [ + ("00", "40", "token-1"), + ("40", "80", None), + ] + assert self._bounds(state.head_range) == ("00", "40") + assert state.head_bc == "token-1" + assert state.page_size_hint == 9 + + def test_from_inbound_preserves_per_entry_backend_continuations(self): + # Future non-sequential / parallel-loop case: a saved token + # where multiple ``c[i]`` entries each carry their own non-null + # backend continuation. All must round-trip into the queue + # untouched. + inbound = { + _FIELD_CONTINUATIONS: [ + {"min": "00", "max": "40", _FIELD_BACKEND_CONTINUATION: "B-cont-5"}, + {"min": "40", "max": "80", _FIELD_BACKEND_CONTINUATION: "A-cont-5"}, + ], + } + state = _FeedRangePaginationState.from_inbound(inbound, page_size_hint=None) + assert self._queue_bounds(state) == [ + ("00", "40", "B-cont-5"), + ("40", "80", "A-cont-5"), + ] + + # When the loop drains the head slice, the next entry's saved + # backend continuation is naturally exposed as the new head_bc. + state.apply_post_result(items_returned=0, backend_continuation=None) + assert self._bounds(state.head_range) == ("40", "80") + assert state.head_bc == "A-cont-5" + assert self._queue_bounds(state) == [("40", "80", "A-cont-5")] + + def test_apply_post_result_with_continuation_updates_head_bc_in_place(self): + current = _mk_range("00", "40") + next_range = _mk_range("40", "80") + state = _FeedRangePaginationState( + queue=[(current, None), (next_range, None)], + page_size_hint=5, + ) + + state.apply_post_result(items_returned=2, backend_continuation="token-2") + + assert self._queue_bounds(state) == [ + ("00", "40", "token-2"), + ("40", "80", None), + ] + assert state.head_bc == "token-2" + assert state.page_size_hint == 5 + + def test_apply_post_result_with_none_pops_head(self): + current = _mk_range("00", "40") + next_range = _mk_range("40", "80") + state = _FeedRangePaginationState( + queue=[(current, "token-1"), (next_range, None)], + page_size_hint=6, + ) + + state.apply_post_result(items_returned=1, backend_continuation=None) + + assert self._queue_bounds(state) == [("40", "80", None)] + assert self._bounds(state.head_range) == ("40", "80") + assert state.head_bc is None + assert state.page_size_hint == 6 + + def test_apply_post_result_with_none_and_no_tail_drains_queue(self): + current = _mk_range("00", "40") + state = _FeedRangePaginationState( + queue=[(current, "token-1")], + page_size_hint=None, + ) + + state.apply_post_result(items_returned=1, backend_continuation=None) + + assert list(state.queue) == [] + assert state.head_range is None + assert state.head_bc is None + + def test_explode_on_multi_overlap_single_overlap_keeps_state(self): + current = _mk_range("00", "80") + tail = _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(current, "token-1"), (tail, None)], + page_size_hint=4, + ) + + did_explode = state.explode_on_multi_overlap([self._pkr("X", "00", "80")]) + + assert did_explode is False + assert self._queue_bounds(state) == [ + ("00", "80", "token-1"), + ("80", "C0", None), + ] + + def test_explode_on_multi_overlap_replaces_head_with_children(self): + current = _mk_range("00", "80") + tail = _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(current, "token-1"), (tail, None)], + page_size_hint=4, + ) + + did_explode = state.explode_on_multi_overlap( + [ + self._pkr("X1", "00", "40"), + self._pkr("X2", "40", "80"), + ] + ) + + assert did_explode is True + # Parent is dequeued and children are appended at the queue tail + # in EPK order (Java/.NET parity). Children inherit the parent's + # backend continuation so resume progress is preserved across split. + assert self._queue_bounds(state) == [ + ("80", "C0", None), + ("00", "40", "token-1"), + ("40", "80", "token-1"), + ] + assert state.head_bc is None + + def test_explode_on_multi_overlap_with_no_parent_continuation_keeps_children_none(self): + current = _mk_range("00", "80") + tail = _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(current, None), (tail, None)], + page_size_hint=4, + ) + + did_explode = state.explode_on_multi_overlap( + [ + self._pkr("X1", "00", "40"), + self._pkr("X2", "40", "80"), + ] + ) + + assert did_explode is True + assert self._queue_bounds(state) == [ + ("80", "C0", None), + ("00", "40", None), + ("40", "80", None), + ] + + +class TestCheckpointRoundTripOnException: + """If the per-iteration POST raises mid-page, the call site stamps + the current pagination state into the outbound continuation header + before re-raising. That checkpoint must round-trip so the caller + can retry from exactly where the failed POST left off — never from + the start of the head sub-range, never skipping it. + + These tests drive the state machine + token codec end-to-end + without standing up a live client: emit the checkpoint that the + sync/async loops would emit on exception, then resume from it and + assert the queue is intact. + """ + + @staticmethod + def _bounds(rng: routing_range.Range) -> tuple[str, str]: + return rng.min, rng.max + + def test_checkpoint_emitted_mid_page_resumes_at_same_head(self): + # Simulate: queue is [A (in-flight, bc=cont-A), B, C]; the POST + # for A returned cont-A successfully on a previous iteration but + # the next POST (still on A) is about to raise. The call site + # writes the outbound continuation BEFORE re-raising — that + # token must put us back at (A, cont-A) on resume, with B and C + # still queued behind it untouched. + a, b, c = _mk_range("00", "40"), _mk_range("40", "80"), _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(a, "cont-A"), (b, None), (c, None)], + page_size_hint=7, + ) + + headers: dict = {} + state.write_outbound_continuation(headers, _RID, _QUERY, _FEED_RANGE) + + # The header carries an opaque, base64-encoded v=1 envelope. + wire = headers[http_constants.HttpHeaders.Continuation] + assert isinstance(wire, str) and wire + + # Caller resumes from that exact wire token. + inbound = _decode_token(wire) + assert inbound is not None + _validate_token_identity(inbound, _RID, _QUERY, _FEED_RANGE) + resumed = _FeedRangePaginationState.from_inbound(inbound, page_size_hint=7) + + # Head is still A with its bc; tail (B, C) is intact and bc-free. + assert self._bounds(resumed.head_range) == ("00", "40") + assert resumed.head_bc == "cont-A" + assert [(r.min, r.max, bc) for r, bc in resumed.queue] == [ + ("00", "40", "cont-A"), + ("40", "80", None), + ("80", "C0", None), + ] + + def test_checkpoint_after_partial_drain_resumes_at_next_head(self): + # Simulate: A was fully drained (apply_post_result with bc=None + # popped it) and the POST for B is about to raise. The + # checkpoint must put us at (B, None) on resume with C still + # queued, and A must NOT reappear. + a, b, c = _mk_range("00", "40"), _mk_range("40", "80"), _mk_range("80", "C0") + state = _FeedRangePaginationState( + queue=[(a, "cont-A-final"), (b, None), (c, None)], + page_size_hint=5, + ) + state.apply_post_result(items_returned=3, backend_continuation=None) + # A is now drained; head should be B. + assert self._bounds(state.head_range) == ("40", "80") + + headers: dict = {} + state.write_outbound_continuation(headers, _RID, _QUERY, _FEED_RANGE) + + resumed = _FeedRangePaginationState.from_inbound( + _decode_token(headers[http_constants.HttpHeaders.Continuation]), + page_size_hint=2, + ) + assert [(r.min, r.max, bc) for r, bc in resumed.queue] == [ + ("40", "80", None), + ("80", "C0", None), + ] + + def test_drained_state_clears_continuation_header(self): + # When the entire queue is drained, the call site must clear + # the outbound continuation header (not leave a stale one + # behind) so the caller's by_page loop terminates. + headers = {http_constants.HttpHeaders.Continuation: "stale-from-prior-page"} + state = _FeedRangePaginationState(queue=[], page_size_hint=None) + + state.write_outbound_continuation(headers, _RID, _QUERY, _FEED_RANGE) + + assert http_constants.HttpHeaders.Continuation not in headers + diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index 5564136002b7..32d2aa0399e0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -14,7 +14,8 @@ from azure.cosmos import exceptions from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext -from azure.cosmos.http_constants import StatusCodes, SubStatusCodes +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders + # tracemalloc is not available in PyPy, so we import conditionally try: @@ -38,6 +39,7 @@ class MockClient: """Mock Cosmos client for testing partition split retry logic.""" def __init__(self): self._global_endpoint_manager = MockGlobalEndpointManager() + self.last_response_headers = {} self.refresh_routing_map_provider_call_count = 0 def refresh_routing_map_provider(self): @@ -152,6 +154,167 @@ def mock_fetch_function(options): "refresh_routing_map_provider should be called once on 410" assert result == expected_docs, "Should return expected documents after retry" + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_uses_checkpoint_continuation_from_internal_capture(self, mock_execute): + """410 retry should resume from checkpoint continuation stamped by __QueryFeed.""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token" + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["checkpoint-token"] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_ignores_stale_shared_client_headers(self, mock_execute): + """Retry resumes from request-local captured headers, not shared client headers.""" + mock_client = MockClient() + mock_client.last_response_headers = {HttpHeaders.Continuation: "stale-global-token"} + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "fresh-checkpoint" + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["fresh-checkpoint"] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_without_checkpoint_continuation_retries_from_none(self, mock_execute): + """If no checkpoint header is stamped, continuation should remain None on retry.""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture.clear() + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == [None] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_multiple_410_uses_latest_checkpoint_continuation(self, mock_execute): + """Across repeated 410 retries, execution should resume using the latest checkpoint token.""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-1" + raise create_410_partition_split_error() + if call_count[0] == 2: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-2" + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 3 + assert seen_continuations == ["checkpoint-token-2"] + assert result == expected_docs + + @patch('azure.cosmos._retry_utility.Execute') + def test_mid_pagination_split_retries_from_checkpoint_without_duplicates(self, mock_execute): + """Simulate page2 split and verify retry resumes from checkpoint token, not from page1.""" + mock_client = MockClient() + + docs_page_1 = [{"id": "1"}, {"id": "2"}, {"id": "3"}, {"id": "4"}, {"id": "5"}] + docs_page_2 = [{"id": "6"}, {"id": "7"}, {"id": "8"}, {"id": "9"}, {"id": "10"}] + + def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + return callback() + + mock_execute.side_effect = execute_side_effect + + fetch_calls = [] + + def mock_fetch_function(options): + continuation = options.get("continuation") + fetch_calls.append(continuation) + + if continuation is None: + return (docs_page_1, {HttpHeaders.Continuation: "token-after-page-1"}) + + if continuation == "token-after-page-1": + # Simulate __QueryFeed writing a checkpoint before re-raising split error. + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-after-split" + raise create_410_partition_split_error() + + if continuation == "checkpoint-after-split": + return (docs_page_2, {}) + + self.fail(f"Unexpected continuation seen by fetch: {continuation}") + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + first_result = context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(first_result, docs_page_1) + + second_result = context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(second_result, docs_page_2) + + # Validate the second page did not replay page-1 items and resumed from checkpoint. + self.assertEqual(fetch_calls, [None, "token-after-page-1", "checkpoint-after-split"]) + @patch('azure.cosmos._retry_utility.Execute') def test_pk_range_query_skips_410_retry_to_prevent_recursion(self, mock_execute): """ diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index a487fc6a85eb..8d09fb7fcd31 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -13,7 +13,7 @@ import pytest from azure.cosmos import exceptions -from azure.cosmos.http_constants import StatusCodes, SubStatusCodes +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders from azure.cosmos.aio import CosmosClient # noqa: F401 - needed to resolve circular imports from azure.cosmos._execution_context.aio.base_execution_context import _DefaultQueryExecutionContext @@ -39,6 +39,7 @@ class MockClient: """Mock Cosmos client for testing partition split retry logic.""" def __init__(self): self._global_endpoint_manager = MockGlobalEndpointManager() + self.last_response_headers = {} self.refresh_routing_map_provider_call_count = 0 def refresh_routing_map_provider(self): @@ -151,6 +152,167 @@ async def mock_fetch_function(options): "refresh_routing_map_provider should be called once on 410" assert result == expected_docs, "Should return expected documents after retry" + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_uses_checkpoint_continuation_from_internal_capture_async(self, mock_execute): + """410 retry should resume from checkpoint continuation stamped by __QueryFeed (async).""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token" + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["checkpoint-token"] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_ignores_stale_shared_client_headers_async(self, mock_execute): + """Retry resumes from request-local captured headers, not shared client headers.""" + mock_client = MockClient() + mock_client.last_response_headers = {HttpHeaders.Continuation: "stale-global-token"} + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + context = _DefaultQueryExecutionContext(mock_client, {}, lambda _options: (expected_docs, {})) + + async def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "fresh-checkpoint" + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context._fetch_function = mock_fetch_function + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == ["fresh-checkpoint"] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_without_checkpoint_continuation_retries_from_none_async(self, mock_execute): + """If no checkpoint header is stamped, continuation should remain None on retry (async).""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture.clear() + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2 + assert seen_continuations == [None] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_multiple_410_uses_latest_checkpoint_continuation_async(self, mock_execute): + """Across repeated 410 retries, execution should resume using the latest checkpoint token (async).""" + mock_client = MockClient() + expected_docs = [{"id": "success"}] + seen_continuations = [] + call_count = [0] + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-1" + raise create_410_partition_split_error() + if call_count[0] == 2: + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-token-2" + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + seen_continuations.append(options.get("continuation")) + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 3 + assert seen_continuations == ["checkpoint-token-2"] + assert result == expected_docs + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_mid_pagination_split_retries_from_checkpoint_without_duplicates_async(self, mock_execute): + """Simulate page2 split and verify async retry resumes from checkpoint token, not from page1.""" + mock_client = MockClient() + + docs_page_1 = [{"id": "1"}, {"id": "2"}, {"id": "3"}, {"id": "4"}, {"id": "5"}] + docs_page_2 = [{"id": "6"}, {"id": "7"}, {"id": "8"}, {"id": "9"}, {"id": "10"}] + + async def execute_side_effect(client, _global_endpoint_manager, callback, **kwargs): + return await callback() + + mock_execute.side_effect = execute_side_effect + + fetch_calls = [] + + async def mock_fetch_function(options): + continuation = options.get("continuation") + fetch_calls.append(continuation) + + if continuation is None: + return (docs_page_1, {HttpHeaders.Continuation: "token-after-page-1"}) + + if continuation == "token-after-page-1": + # Simulate __QueryFeed writing a checkpoint before re-raising split error. + context._internal_response_headers_capture[HttpHeaders.Continuation] = "checkpoint-after-split" + raise create_410_partition_split_error() + + if continuation == "checkpoint-after-split": + return (docs_page_2, {}) + + self.fail(f"Unexpected continuation seen by fetch: {continuation}") + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + first_result = await context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(first_result, docs_page_1) + + second_result = await context._fetch_items_helper_with_retries(mock_fetch_function) + self.assertListEqual(second_result, docs_page_2) + + # Validate the second page did not replay page-1 items and resumed from checkpoint. + self.assertEqual(fetch_calls, [None, "token-after-page-1", "checkpoint-after-split"]) + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') async def test_pk_range_query_skips_410_retry_to_prevent_recursion_async(self, mock_execute): """ diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 4970723e7757..ae38147fb092 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -4,6 +4,8 @@ import os import unittest import uuid +from contextlib import contextmanager +from unittest.mock import patch import pytest @@ -12,6 +14,7 @@ import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import http_constants, DatabaseProxy, _endpoint_discovery_retry_policy +from azure.cosmos._routing.feed_range_continuation import _decode_token from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos.documents import _DistinctType @@ -41,6 +44,19 @@ def setUpClass(cls): if cls.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" + @contextmanager + def _new_client_with_structured_full_pk_env(self, value: str): + use_multiple_write_locations = os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True" + with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": value}, clear=False): + with cosmos_client.CosmosClient( + self.host, + self.credential, + multiple_write_locations=use_multiple_write_locations, + ) as client: + database = client.get_database_client(self.TEST_DATABASE_ID) + created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + yield client, created_collection + def test_first_and_last_slashes_trimmed_for_query_string(self): created_collection = self.created_db.create_container( "test_trimmed_slashes", PartitionKey(path="/pk")) @@ -469,6 +485,385 @@ def test_cross_partition_query_with_continuation_token(self): self.assertEqual(second_page['id'], second_page_fetched_with_continuation_token['id']) + def test_full_pk_continuation_emits_legacy_by_default(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + + self.assertIsNotNone(token) + self.assertIsNone(_decode_token(token)) + + def test_full_pk_continuation_emits_structured_with_env_var(self): + with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + + self.assertIsNotNone(token) + self.assertIsNotNone(_decode_token(token)) + + def test_full_pk_continuation_emits_structured_with_env_var_and_new_client(self): + with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": "true"}, clear=False): + with cosmos_client.CosmosClient( + self.host, + self.credential, + ) as client: + database = client.get_database_client(self.TEST_DATABASE_ID) + created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + + self.assertIsNotNone(token) + self.assertIsNotNone(_decode_token(token)) + + def test_full_pk_legacy_replay_resumes_same_page(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + second_page = list(pager.next())[0] + + self.assertIsNotNone(token) + self.assertIsNone(_decode_token(token)) + + replay_pager = query_iterable.by_page(token) + replay_second_page = list(replay_pager.next())[0] + self.assertEqual(second_page['id'], replay_second_page['id']) + + def test_full_pk_structured_replay_resumes_same_page(self): + with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + second_page = list(pager.next())[0] + + self.assertIsNotNone(token) + self.assertIsNotNone(_decode_token(token)) + + replay_pager = query_iterable.by_page(token) + replay_second_page = list(replay_pager.next())[0] + self.assertEqual(second_page['id'], replay_second_page['id']) + + def test_full_pk_structured_replay_rejects_query_mismatch(self): + with self._new_client_with_structured_full_pk_env("on") as (_, created_collection): + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + source_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + source_pager = source_iterable.by_page() + source_pager.next() + token = source_pager.continuation_token + self.assertIsNotNone(_decode_token(token)) + + mismatched_query_iterable = created_collection.query_items( + query='SELECT VALUE c.id from c', + partition_key='pk', + max_item_count=1, + ) + with self.assertRaisesRegex(ValueError, 'query hash mismatch'): + mismatched_query_iterable.by_page(token).next() + + def test_full_pk_structured_replay_rejects_partition_key_mismatch(self): + with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk2', 'id': str(uuid.uuid4())}) + source_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + source_pager = source_iterable.by_page() + source_pager.next() + token = source_pager.continuation_token + self.assertIsNotNone(_decode_token(token)) + + mismatched_pk_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk2', + max_item_count=1, + ) + with self.assertRaisesRegex(ValueError, 'feed_range hash mismatch'): + mismatched_pk_iterable.by_page(token).next() + + def test_mixed_version_structured_token_replayed_by_legacy_mode(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + new_mode_pager = new_mode_iterable.by_page() + new_mode_pager.next() + structured_token = new_mode_pager.continuation_token + self.assertIsNotNone(_decode_token(structured_token)) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page(structured_token) + list(legacy_mode_pager.next()) + resumed_continuation = legacy_mode_pager.continuation_token + self.assertIsNotNone(resumed_continuation) + self.assertIsNone(_decode_token(resumed_continuation)) + + def test_mixed_version_legacy_token_replayed_by_structured_mode(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + legacy_mode_pager.next() + legacy_token = legacy_mode_pager.continuation_token + self.assertIsNone(_decode_token(legacy_token)) + + with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + new_mode_pager = new_mode_iterable.by_page(legacy_token) + list(new_mode_pager.next()) + resumed_continuation = new_mode_pager.continuation_token + self.assertIsNotNone(resumed_continuation) + self.assertIsNotNone(_decode_token(resumed_continuation)) + + def test_full_pk_split_during_page_resets_retry_state(self): + pk_value = 'pk-' + str(uuid.uuid4()) + inserted_ids = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())] + with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): + for doc_id in inserted_ids: + created_collection.upsert_item(body={'pk': pk_value, 'id': doc_id}) + query_iterable = created_collection.query_items( + query='SELECT * from c ORDER BY c.id', + partition_key=pk_value, + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + continuation_token = pager.continuation_token + pager.next() + + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + injected_split = False + + def _split_once_post(*args, **kwargs): + nonlocal injected_split + req_headers = args[3] + if ( + not injected_split + and req_headers.get(http_constants.HttpHeaders.Continuation) + ): + injected_split = True + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.GONE, + sub_status=http_constants.SubStatusCodes.PARTITION_KEY_RANGE_GONE, + message='simulated split during full-pk page fetch', + ) + return original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _split_once_post + try: + replay_pager = query_iterable.by_page(continuation_token) + replay_second_page = list(replay_pager.next())[0] + self.assertTrue(injected_split) + self.assertIn(replay_second_page['id'], inserted_ids) + self.assertIsNotNone(_decode_token(replay_pager.continuation_token)) + finally: + client_conn._CosmosClientConnection__Post = original_post + + def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + legacy_mode_pager.next() + legacy_token = legacy_mode_pager.continuation_token + self.assertIsNone(_decode_token(legacy_token)) + + with self._new_client_with_structured_full_pk_env("on") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + client_conn = structured_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + post_call_count = 0 + + def _failing_post(*args, **kwargs): + nonlocal post_call_count + post_call_count += 1 + if post_call_count == 1: + raise RuntimeError("bridge-runtime-error") + return original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with self.assertRaisesRegex(RuntimeError, 'bridge-runtime-error'): + new_mode_iterable.by_page(legacy_token).next() + finally: + client_conn._CosmosClientConnection__Post = original_post + + def test_full_pk_legacy_bridge_does_not_fallback_on_unrelated_service_error(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + legacy_mode_pager.next() + legacy_token = legacy_mode_pager.continuation_token + self.assertIsNone(_decode_token(legacy_token)) + + with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + client_conn = structured_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + saw_legacy_fallback_headers = False + + def _failing_post(*args, **kwargs): + nonlocal saw_legacy_fallback_headers + req_headers = args[3] + if ( + http_constants.HttpHeaders.PartitionKey in req_headers + and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token + and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers + ): + saw_legacy_fallback_headers = True + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, + message="throttled", + ) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with self.assertRaises(exceptions.CosmosHttpResponseError): + new_mode_iterable.by_page(legacy_token).next() + self.assertFalse(saw_legacy_fallback_headers) + finally: + client_conn._CosmosClientConnection__Post = original_post + + def test_full_pk_legacy_bridge_fallback_failure_stamps_checkpoint(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + legacy_mode_pager.next() + legacy_token = legacy_mode_pager.continuation_token + self.assertIsNone(_decode_token(legacy_token)) + + with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + client_conn = structured_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + post_call_count = 0 + + def _failing_post(*args, **kwargs): + nonlocal post_call_count + post_call_count += 1 + if post_call_count == 1: + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.BAD_REQUEST, + message="legacy bridge compatibility failure", + ) + raise RuntimeError("fallback-post-failed") + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with self.assertRaisesRegex(RuntimeError, 'fallback-post-failed'): + new_mode_iterable.by_page(legacy_token).next() + self.assertEqual(post_call_count, 2) + continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + self.assertIsNotNone(continuation) + self.assertIsNotNone(_decode_token(continuation)) + finally: + client_conn._CosmosClientConnection__Post = original_post + def test_cross_partition_query_with_none_partition_key(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) document_definition = {'pk': 'pk1', 'id': str(uuid.uuid4())} diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index 40207abddf14..fd3d65f1bfbd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -4,7 +4,9 @@ import os import unittest import uuid +from contextlib import asynccontextmanager from asyncio import gather +from unittest.mock import patch import pytest @@ -12,6 +14,7 @@ import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import http_constants, _endpoint_discovery_retry_policy +from azure.cosmos._routing.feed_range_continuation import _decode_token from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos._retry_options import RetryOptions from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy @@ -55,6 +58,18 @@ async def asyncSetUp(self): async def asyncTearDown(self): await self.client.close() + @asynccontextmanager + async def _new_client_with_structured_full_pk_env(self, value: str): + with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": value}, clear=False): + async with CosmosClient( + self.host, + self.masterKey, + multiple_write_locations=self.use_multiple_write_locations, + ) as client: + database = client.get_database_client(self.TEST_DATABASE_ID) + created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + yield client, created_collection + async def test_first_and_last_slashes_trimmed_for_query_string_async(self): created_collection = await self.created_db.create_container( str(uuid.uuid4()), PartitionKey(path="/pk")) @@ -465,6 +480,387 @@ async def test_cross_partition_query_with_continuation_token_async(self): assert second_page['id'] == second_page_fetched_with_continuation_token['id'] + async def test_full_pk_continuation_emits_legacy_by_default_async(self): + """Full partition-key queries return legacy continuation tokens by default.""" + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + + assert token is not None + assert _decode_token(token) is None + + async def test_full_pk_continuation_emits_structured_with_env_var_async(self): + """Enabling the environment variable returns structured continuation tokens.""" + async with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + + assert token is not None + assert _decode_token(token) is not None + + async def test_full_pk_continuation_emits_structured_with_env_var_and_new_client_async(self): + """The environment variable is read when the client is created.""" + with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": "true"}, clear=False): + async with CosmosClient( + self.host, + self.masterKey, + ) as client: + database = client.get_database_client(self.TEST_DATABASE_ID) + created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + + assert token is not None + assert _decode_token(token) is not None + + async def test_full_pk_legacy_replay_resumes_same_page_async(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + second_page = [item async for item in await pager.__anext__()][0] + + assert token is not None + assert _decode_token(token) is None + + replay_pager = query_iterable.by_page(token) + replay_second_page = [item async for item in await replay_pager.__anext__()][0] + assert second_page['id'] == replay_second_page['id'] + + async def test_full_pk_structured_replay_resumes_same_page_async(self): + async with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + query_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + token = pager.continuation_token + second_page = [item async for item in await pager.__anext__()][0] + replay_pager = query_iterable.by_page(token) + replay_second_page = [item async for item in await replay_pager.__anext__()][0] + assert token is not None + assert _decode_token(token) is not None + assert second_page['id'] == replay_second_page['id'] + + async def test_full_pk_structured_replay_rejects_query_mismatch_async(self): + async with self._new_client_with_structured_full_pk_env("on") as (_, created_collection): + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + source_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + source_pager = source_iterable.by_page() + await source_pager.__anext__() + token = source_pager.continuation_token + assert _decode_token(token) is not None + + mismatched_query_iterable = created_collection.query_items( + query='SELECT VALUE c.id from c', + partition_key='pk', + max_item_count=1, + ) + with pytest.raises(ValueError, match='query hash mismatch'): + await mismatched_query_iterable.by_page(token).__anext__() + + async def test_full_pk_structured_replay_rejects_partition_key_mismatch_async(self): + async with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk2', 'id': str(uuid.uuid4())}) + source_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + source_pager = source_iterable.by_page() + await source_pager.__anext__() + token = source_pager.continuation_token + assert _decode_token(token) is not None + + mismatched_pk_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk2', + max_item_count=1, + ) + with pytest.raises(ValueError, match='feed_range hash mismatch'): + await mismatched_pk_iterable.by_page(token).__anext__() + + async def test_mixed_version_structured_token_replayed_by_legacy_mode_async(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + async with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + new_mode_pager = new_mode_iterable.by_page() + await new_mode_pager.__anext__() + structured_token = new_mode_pager.continuation_token + assert _decode_token(structured_token) is not None + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page(structured_token) + await legacy_mode_pager.__anext__() + resumed_continuation = legacy_mode_pager.continuation_token + assert resumed_continuation is not None + assert _decode_token(resumed_continuation) is None + + async def test_mixed_version_legacy_token_replayed_by_structured_mode_async(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + await legacy_mode_pager.__anext__() + legacy_token = legacy_mode_pager.continuation_token + assert _decode_token(legacy_token) is None + + async with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + new_mode_pager = new_mode_iterable.by_page(legacy_token) + await new_mode_pager.__anext__() + resumed_continuation = new_mode_pager.continuation_token + assert resumed_continuation is not None + assert _decode_token(resumed_continuation) is not None + + async def test_full_pk_split_during_page_resets_retry_state_async(self): + pk_value = 'pk-' + str(uuid.uuid4()) + inserted_ids = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())] + + async with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): + for doc_id in inserted_ids: + await created_collection.upsert_item(body={'pk': pk_value, 'id': doc_id}) + query_iterable = created_collection.query_items( + query='SELECT * from c ORDER BY c.id', + partition_key=pk_value, + max_item_count=1, + ) + pager = query_iterable.by_page() + await pager.__anext__() + continuation_token = pager.continuation_token + await pager.__anext__() + + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + injected_split = False + + async def _split_once_post(*args, **kwargs): + nonlocal injected_split + req_headers = args[3] + if ( + not injected_split + and req_headers.get(http_constants.HttpHeaders.Continuation) + ): + injected_split = True + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.GONE, + sub_status=http_constants.SubStatusCodes.PARTITION_KEY_RANGE_GONE, + message='simulated split during full-pk page fetch async', + ) + return await original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _split_once_post + try: + replay_pager = query_iterable.by_page(continuation_token) + replay_second_page = [item async for item in await replay_pager.__anext__()][0] + assert injected_split + assert replay_second_page['id'] in inserted_ids + assert _decode_token(replay_pager.continuation_token) is not None + finally: + client_conn._CosmosClientConnection__Post = original_post + + async def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error_async(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + await legacy_mode_pager.__anext__() + legacy_token = legacy_mode_pager.continuation_token + assert _decode_token(legacy_token) is None + + async with self._new_client_with_structured_full_pk_env("on") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + client_conn = structured_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + post_call_count = 0 + + async def _failing_post(*args, **kwargs): + nonlocal post_call_count + post_call_count += 1 + if post_call_count == 1: + raise RuntimeError("bridge-runtime-error-async") + return await original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with pytest.raises(RuntimeError, match='bridge-runtime-error-async'): + await new_mode_iterable.by_page(legacy_token).__anext__() + finally: + client_conn._CosmosClientConnection__Post = original_post + + async def test_full_pk_legacy_bridge_does_not_fallback_on_unrelated_service_error_async(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + await legacy_mode_pager.__anext__() + legacy_token = legacy_mode_pager.continuation_token + assert _decode_token(legacy_token) is None + + async with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + client_conn = structured_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + saw_legacy_fallback_headers = False + + async def _failing_post(*args, **kwargs): + nonlocal saw_legacy_fallback_headers + req_headers = args[3] + if ( + http_constants.HttpHeaders.PartitionKey in req_headers + and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token + and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers + ): + saw_legacy_fallback_headers = True + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, + message="throttled", + ) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with pytest.raises(exceptions.CosmosHttpResponseError): + await new_mode_iterable.by_page(legacy_token).__anext__() + assert not saw_legacy_fallback_headers + finally: + client_conn._CosmosClientConnection__Post = original_post + + async def test_full_pk_legacy_bridge_fallback_failure_stamps_checkpoint_async(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + + legacy_mode_iterable = created_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + legacy_mode_pager = legacy_mode_iterable.by_page() + await legacy_mode_pager.__anext__() + legacy_token = legacy_mode_pager.continuation_token + assert _decode_token(legacy_token) is None + + async with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): + new_mode_iterable = structured_collection.query_items( + query='SELECT * from c', + partition_key='pk', + max_item_count=1, + ) + client_conn = structured_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + post_call_count = 0 + + async def _failing_post(*args, **kwargs): + nonlocal post_call_count + post_call_count += 1 + if post_call_count == 1: + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.BAD_REQUEST, + message="legacy bridge compatibility failure", + ) + raise RuntimeError("fallback-post-failed-async") + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with pytest.raises(RuntimeError, match='fallback-post-failed-async'): + await new_mode_iterable.by_page(legacy_token).__anext__() + assert post_call_count == 2 + continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + assert continuation is not None + assert _decode_token(continuation) is not None + finally: + client_conn._CosmosClientConnection__Post = original_post + async def test_cross_partition_query_with_none_partition_key_async(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) document_definition = {'pk': 'pk1', 'id': str(uuid.uuid4())} diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py new file mode 100644 index 000000000000..8af9f0517310 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py @@ -0,0 +1,1005 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""End-to-end tests for ``query_items(feed_range=...)`` against a feed_range +that overlaps more than one physical partition. + +These tests pin two invariants of the multi-overlap pagination contract: + +* every page returns at most ``max_item_count`` items, and +* no item id is returned on more than one page (no duplicates across the + fan-out / resume boundary). + +Three scenarios are covered: + +* ``test_two_partition_feed_range`` — feed_range overlaps two adjacent + physical partitions. +* ``test_three_way_overlap`` — feed_range overlaps three adjacent physical + partitions; wider fan-out. +* ``test_post_split_resume`` — emit a continuation under one physical + layout, force a real partition split, then resume with the same + continuation against the new layout. Slow (``cosmosSplit``). + + +Async parity lives in ``test_query_feed_range_multipartition_async.py``. +""" + +import time +import unittest +import uuid +from typing import Iterable, List, Optional, Tuple + +import pytest + +import test_config +from azure.cosmos import _base +from azure.cosmos import CosmosClient, documents, http_constants +from azure.cosmos._routing.feed_range_continuation import _decode_token +from azure.cosmos.partition_key import PartitionKey + +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID + +# Dedicated container for these tests. Throughput is chosen so the +# container has multiple physical partitions out of the box (no split needed +# for the two/three-overlap tests); the post-split test then forces an +# additional split on top. +REPRO_CONTAINER_ID = "FeedRangeMultiPartition-" + str(uuid.uuid4()) +REPRO_PARTITION_KEY = "pk" +REPRO_THROUGHPUT = CONFIG.THROUGHPUT_FOR_5_PARTITIONS # 30000 → ~5 partitions +REPRO_DOC_COUNT = 200 # spread across partitions; ensures every partition has data + +# Per-page cap applied to every multi-overlap query in this module. +# Small enough to drive several pages per partition under the seeded data +# count, so any per-page over-fetch or duplicate-on-resume shows up across +# the page sequence rather than only on the last page. +PAGE_SIZE = 5 + +# Per-overlap data threshold below which we skip a configuration as not a +# meaningful repro. Need enough docs in each partition to drive ≥ 3 pages +# under PAGE_SIZE = 5. +MIN_DOCS_PER_PARTITION = 15 + + +def _client() -> CosmosClient: + return CosmosClient(HOST, KEY) + + +def _get_container(): + db = _client().get_database_client(DATABASE_ID) + return db.get_container_client(REPRO_CONTAINER_ID) + + +def _sorted_partition_ranges(container) -> List[Tuple[str, str]]: + """Return current physical partitions' EPK ranges as (min, max) tuples, + sorted by ``min``. Reads the routing map via ``read_feed_ranges()`` (the + public surface that returns one dict per current physical partition). + """ + feed_ranges = list(container.read_feed_ranges()) + pairs: List[Tuple[str, str]] = [] + for fr in feed_ranges: + r = fr["Range"] + pairs.append((r["min"], r["max"])) + pairs.sort(key=lambda p: p[0]) + return pairs + + +def _count_in_range(container, range_min: str, range_max: str) -> int: + fr = test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + items = list(container.query_items( + query="SELECT VALUE COUNT(1) FROM c", feed_range=fr)) + return items[0] if items else 0 + + +def _crossing_feed_range(range_min: str, range_max: str): + """Synthesize a feed_range whose ``[min, max)`` interval spans the union + of one or more current physical partitions — the shape a feed_range + takes after the underlying partition has been split.""" + return test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + + +def _ids_via_per_partition_scan(container, partition_ranges: Iterable[Tuple[str, str]]): + """Ground-truth set of document ids inside the union of the given physical + partition ranges. Each partition is queried independently (each call is a + single-overlap query that does NOT exercise the multi-overlap fan-out + branch), so this is the correct baseline to compare a crossing-feed_range + query against.""" + ground_truth = set() + for (mn, mx) in partition_ranges: + fr = _crossing_feed_range(mn, mx) + for item in container.query_items(query="SELECT c.id FROM c", feed_range=fr): + ground_truth.add(item["id"]) + return ground_truth + + +def _drain_pages(pager) -> Tuple[List[List[dict]], List[str]]: + """Iterate ``pager`` to exhaustion. Return the per-page item lists (so the + caller can assert on per-page sizes) and the ordered list of all ids + encountered (so the caller can assert on duplicates).""" + pages: List[List[dict]] = [] + all_ids: List[str] = [] + for page in pager: + items = list(page) + pages.append(items) + all_ids.extend(item["id"] for item in items) + return pages, all_ids + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(): + """Create a dedicated container for these tests, populate it with enough + documents that every physical partition has data on both sides of every + internal boundary, and tear it down afterwards.""" + client = _client() + db = client.get_database_client(DATABASE_ID) + container = db.create_container_if_not_exists( + id=REPRO_CONTAINER_ID, + partition_key=PartitionKey(path="/" + REPRO_PARTITION_KEY, kind="Hash"), + offer_throughput=REPRO_THROUGHPUT) + # Insert REPRO_DOC_COUNT documents with distinct partition-key values. + # SHA-based PK hashing distributes these roughly uniformly across the + # container's physical partitions, so each partition ends up with a few + # dozen documents — enough to drive multiple pages at PAGE_SIZE=5. + for i in range(REPRO_DOC_COUNT): + container.upsert_item({ + REPRO_PARTITION_KEY: f"pk-{i:04d}", + "id": f"doc-{i:04d}", + "value": i, + }) + yield + try: + db.delete_container(REPRO_CONTAINER_ID) + except Exception: # pylint: disable=broad-except + pass + + +@pytest.mark.cosmosQuery +class TestFeedRangeMultiPartition: + """Sync end-to-end tests for feed_range queries that overlap multiple + physical partitions.""" + + # ------------------------------------------------------------------ # + # Single-partition control (regression guard for the no-fan-out path) + # ------------------------------------------------------------------ # + def test_single_partition_feed_range(self): + """``feed_range`` strictly inside one physical partition's EPK + range, ``max_item_count=PAGE_SIZE``: every page must contain + exactly ``PAGE_SIZE`` items (except possibly the last one), no + duplicates across pages, and the last page's continuation must be + ``None``. + + This is the path the vast majority of feedRanges follow. It does + NOT exercise the multi-overlap fan-out branch; a regression here + means the single-overlap path itself is broken. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if not partitions: + pytest.skip("Container has no physical partitions") + + # Pick the first partition that holds enough docs to drive + # multiple PAGE_SIZE pages. + chosen_pp = None + for (mn, mx) in partitions: + if _count_in_range(container, mn, mx) >= MIN_DOCS_PER_PARTITION: + chosen_pp = (mn, mx) + break + if chosen_pp is None: + pytest.skip("No single partition populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + single = _crossing_feed_range(chosen_pp[0], chosen_pp[1]) + ground_truth = _ids_via_per_partition_scan(container, [chosen_pp]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=single, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + # Every page except the last one is exactly PAGE_SIZE; the last + # page is at most PAGE_SIZE. + for idx, page in enumerate(pages): + if idx < len(pages) - 1: + assert len(page) == PAGE_SIZE, ( + f"page {idx} returned {len(page)} items, expected " + f"exactly {PAGE_SIZE} (only the last page is allowed " + "to be short on the single-overlap path)") + else: + assert len(page) <= PAGE_SIZE + + # No duplicates and full coverage of the partition. + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"single-partition path returned duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"single-partition coverage mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + # After the last page, the continuation token must be None + # (composite drained -> caller's loop terminates correctly). + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages; got " + f"{pager.continuation_token!r}") + + # ------------------------------------------------------------------ # + # Partition-key caller shapes (full key and prefix key) + # ------------------------------------------------------------------ # + def test_full_partition_key_query_pagination_resume(self): + """Full hierarchical partition-key query resumes correctly by continuation. + + This uses a dedicated MultiHash container and a full key value so the + query stays scoped to one logical partition while still exercising + pagination + resume on the partition_key path. + """ + db = _client().get_database_client(DATABASE_ID) + container_id = "FeedRangeMultiPartitionFullPK-" + str(uuid.uuid4()) + created_container = db.create_container_if_not_exists( + id=container_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash), + offer_throughput=400, + ) + try: + full_key = ['CA', 'Oxnard', '93033'] + for i in range(25): + created_container.upsert_item({ + 'id': f'full-pk-doc-{i:03d}', + 'state': full_key[0], + 'city': full_key[1], + 'zipcode': full_key[2], + 'value': i, + }) + for i in range(5): + created_container.upsert_item({ + 'id': f'other-doc-{i:03d}', + 'state': 'WA', + 'city': 'Seattle', + 'zipcode': f'98{i:03d}', + 'value': i, + }) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_container.query_items( + query=query, + partition_key=full_key, + max_item_count=7, + ) + + pager = query_iterable.by_page() + first_page = list(next(pager)) + assert first_page + token = pager.continuation_token + assert token + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resumed_remaining_ids = [] + for page in query_iterable.by_page(token): + resumed_remaining_ids.extend(item['id'] for item in page) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] for item in created_container.query_items(query=query, partition_key=full_key) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + finally: + db.delete_container(created_container.id) + + def test_prefix_partition_key_query_pagination_resume(self): + """Prefix hierarchical partition-key query resumes correctly by continuation. + + The caller provides only the first level (``['CA']``). The query spans + multiple descendants under that prefix and must preserve continuation + correctness on resume. + """ + db = _client().get_database_client(DATABASE_ID) + container_id = "FeedRangeMultiPartitionPrefixPK-" + str(uuid.uuid4()) + created_container = db.create_container_if_not_exists( + id=container_id, + partition_key=PartitionKey(path=['/state', '/city', '/zipcode'], kind=documents.PartitionKind.MultiHash), + offer_throughput=400, + ) + try: + for i in range(30): + created_container.upsert_item({ + 'id': f'ca-doc-{i:03d}', + 'state': 'CA', + 'city': f'city-{i % 5}', + 'zipcode': f'zip-{i:03d}', + 'value': i, + }) + for i in range(6): + created_container.upsert_item({ + 'id': f'wa-doc-{i:03d}', + 'state': 'WA', + 'city': f'city-{i % 2}', + 'zipcode': f'zip-{i:03d}', + 'value': i, + }) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_container.query_items( + query=query, + partition_key=['CA'], + max_item_count=7, + ) + + pager = query_iterable.by_page() + first_page = list(next(pager)) + assert first_page + token = pager.continuation_token + assert token + + expected_remaining_ids = [] + for page in pager: + expected_remaining_ids.extend(item['id'] for item in page) + + resumed_remaining_ids = [] + for page in query_iterable.by_page(token): + resumed_remaining_ids.extend(item['id'] for item in page) + + assert expected_remaining_ids == resumed_remaining_ids + + baseline_ids = [ + item['id'] for item in created_container.query_items(query=query, partition_key=['CA']) + ] + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids + finally: + db.delete_container(created_container.id) + + + # ------------------------------------------------------------------ # + # Two-partition feed_range + # ------------------------------------------------------------------ # + def test_two_partition_feed_range(self): + """Construct a feed_range that overlaps two adjacent physical + partitions and pin three invariants: + + (a) per-page item count ≤ ``max_item_count`` (the fan-out must + not concatenate per-overlap responses into one oversized + logical page), + (b) no duplicate ids across pages (each overlap's outbound + continuation must be preserved so the next page resumes + instead of restarting from offset 0), + (c) the union of ids returned matches the union of ids from + independent per-partition scans (no missing items). + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + # Find the first adjacent pair where both partitions hold enough docs + # to drive ≥ 3 pages under PAGE_SIZE = 5. + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = _ids_via_per_partition_scan( + container, [chosen[0], chosen[1]]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + # (a) every page must respect max_item_count + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages (index, size): {oversized}.") + + # (b) no duplicates across pages + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} items returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + # (c) no missing items vs ground truth + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)} ids, " + f"ground_truth={len(ground_truth)} ids, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + def test_two_partition_feed_range_count_aggregate_pagination(self): + """Run a VALUE aggregate through a two-partition crossing feed_range. + + Guards aggregate-specific invariants on the multi-overlap path: + (a) each logical page still respects ``max_item_count``, + (b) partial aggregate fragments are merged client-side (one scalar + result after draining), + (c) merged count matches an independent per-partition scan baseline. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Ground truth from independent per-partition scans, not aggregate path. + expected_count = len(_ids_via_per_partition_scan(container, [chosen[0], chosen[1]])) + + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + for page in pager: + items = list(page) + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_count, ( + "merged COUNT result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_count}") + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + + @pytest.mark.parametrize("merge_error_type", [TypeError, KeyError]) + def test_two_partition_feed_range_merge_fallback_preserves_rows( + self, monkeypatch, caplog, merge_error_type + ): + """Force merge failures and verify fallback extends docs without loss.""" + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = _ids_via_per_partition_scan(container, [chosen[0], chosen[1]]) + + merge_call_count = 0 + + def _raising_merge(*_args, **_kwargs): + nonlocal merge_call_count + merge_call_count += 1 + raise merge_error_type("injected-merge-failure") + + monkeypatch.setattr(_base, "_merge_query_results", _raising_merge) + + with caplog.at_level("WARNING", logger="azure.cosmos._cosmos_client_connection"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + assert merge_call_count > 0 + assert any( + "Falling back to non-aggregate merge after aggregate merge failure" in record.getMessage() + for record in caplog.records + ), "Expected warning log for merge fallback path" + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages (index, size): {oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"fallback path produced duplicates: {len(all_ids)} items returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + assert unique == ground_truth, ( + f"fallback path dropped/added items: returned={len(unique)} ids, " + f"ground_truth={len(ground_truth)} ids, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + def test_exception_during_post_preserves_resume_checkpoint(self): + """Inject a POST failure mid-query and verify the call site stamps + an outbound continuation that resumes from the last successful slice. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + client_conn = container.client_connection + original_post = client_conn._CosmosClientConnection__Post + call_count = 0 + + def _failing_post(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("injected-post-failure") + return original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + with pytest.raises(RuntimeError, match="injected-post-failure"): + _ = list(next(pager)) + finally: + client_conn._CosmosClientConnection__Post = original_post + + token = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + assert token, "Expected continuation checkpoint to be stamped on POST failure" + decoded = _decode_token(token) + assert decoded is not None + assert decoded["c"][0]["min"] == chosen[1][0], ( + "Checkpoint should resume from the second sub-range after the first " + "sub-range completed successfully before failure." + ) + + def test_explode_iteration_guard_raises_in_query_loop(self, monkeypatch): + """Drive the live ``__QueryFeed`` explode loop until the runtime guard raises.""" + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + p0, p1 = partitions[0], partitions[1] + crossing = _crossing_feed_range(p0[0], p1[1]) + client_conn = container.client_connection + + # Force every routing lookup to look like an unresolved post-split overlap. + def _always_multi_overlap(_rid, feed_ranges, _opts): + head = feed_ranges[0] + return [ + {"id": "left", "minInclusive": head.min, "maxExclusive": head.max}, + {"id": "right", "minInclusive": head.min, "maxExclusive": head.max}, + ] + + monkeypatch.setattr( + client_conn._routing_map_provider, "get_overlapping_ranges", _always_multi_overlap + ) + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation._MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS", 2 + ) + + with pytest.raises(RuntimeError) as excinfo: + list( + container.query_items( + query="SELECT * FROM c", feed_range=crossing, max_item_count=PAGE_SIZE + ).by_page() + ) + assert "split re-resolution" in str(excinfo.value) + + def test_no_progress_guard_logs_warning_in_query_loop(self, monkeypatch, caplog): + """Drive repeated empty pages with unchanged continuation and assert warning emission.""" + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + p0, p1 = partitions[0], partitions[1] + crossing = _crossing_feed_range(p0[0], p1[1]) + client_conn = container.client_connection + + post_call_count = 0 + + def _stalled_post(*_args, **_kwargs): + nonlocal post_call_count + post_call_count += 1 + continuation = "stalled-token" if post_call_count <= 3 else None + return {"Documents": []}, {http_constants.HttpHeaders.Continuation: continuation} + + monkeypatch.setattr(client_conn, "_CosmosClientConnection__Post", _stalled_post) + monkeypatch.setattr( + "azure.cosmos._cosmos_client_connection._MAX_CONSECUTIVE_NO_PROGRESS_PAGES", 2 + ) + + with caplog.at_level("WARNING", logger="azure.cosmos._cosmos_client_connection"): + list( + container.query_items( + query="SELECT * FROM c", feed_range=crossing, max_item_count=PAGE_SIZE + ).by_page() + ) + + assert post_call_count >= 3 + assert any( + "same continuation token" in record.getMessage() for record in caplog.records + ), "Expected warning log from no-progress guard" + + # ------------------------------------------------------------------ # + # Three-way overlap (synthetic, wider fan-out) + # ------------------------------------------------------------------ # + def test_three_way_overlap(self): + """Same shape as ``test_two_partition_feed_range`` but with a + ``feed_range`` that overlaps **three** adjacent physical partitions. + Wider fan-out exercises the same three guarantees as the + two-partition test on a larger overlap set. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 3: + pytest.skip("Need a container with ≥ 3 physical partitions") + + chosen: Optional[List[Tuple[str, str]]] = None + for i in range(len(partitions) - 2): + triple = partitions[i:i + 3] + if all(_count_in_range(container, mn, mx) >= MIN_DOCS_PER_PARTITION + for mn, mx in triple): + chosen = triple + break + if chosen is None: + pytest.skip("No three adjacent partitions all populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + assert chosen is not None # narrows for type checkers; pytest.skip raises + crossing = _crossing_feed_range(chosen[0][0], chosen[2][1]) + ground_truth = _ids_via_per_partition_scan(container, chosen) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = _drain_pages(pager) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated; sizes={[len(p) for p in pages]}; " + f"oversized={oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + # ------------------------------------------------------------------ # + # Post-split resume (slow; requires a real partition split) + # ------------------------------------------------------------------ # + @pytest.mark.cosmosSplit + def test_post_split_resume(self): + """End-to-end "the routing layout changed underneath a saved + continuation token" scenario: + + 1. Construct a 2-overlap crossing feed_range under the *current* + routing map; drain page 1 and save the continuation token. + 2. Trigger a real partition split (``trigger_split``) so the + container's physical partition count grows. The same EPK + ``{min, max}`` interval now overlaps a different (≥ 2) set of + physical partitions. + 3. Resume with the saved continuation token + the same + ``feed_range``. Drain remaining pages. + 4. Assert: combined ids across page 1 + post-split pages are + unique and equal the union of a fresh per-partition scan over + the same EPK interval. + + On the post-split resume path, the saved continuation must remain + valid (or be safely restarted) under the new physical layout - the + combined ids across the split boundary must still be unique and + cover the same EPK interval. + """ + container = _get_container() + partitions_before = _sorted_partition_ranges(container) + if len(partitions_before) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions_before) - 1): + p0, p1 = partitions_before[i], partitions_before[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Step 1 — drain page 1 only and save the outbound continuation. + pager_pre = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + page_1 = list(next(pager_pre)) + ct_after_page_1 = pager_pre.continuation_token + page_1_ids = [item["id"] for item in page_1] + assert ct_after_page_1, ( + "expected a non-empty continuation token after page 1; the " + "feed_range overlaps two partitions and the first page should " + "not have drained the whole interval") + + # Step 2 — trigger a real split. This is the slow step (up to 10 min + # for the offer-replace operation to complete). + target_throughput = max(REPRO_THROUGHPUT * 2, 60000) + try: + test_config.TestConfig.trigger_split(container, target_throughput) + except unittest.SkipTest: + raise + # Allow the routing map a brief settling period after the split + # completes, then force a refresh so the SDK sees the new layout. + time.sleep(10) + list(container.read_feed_ranges(force_refresh=True)) + + # Step 3 — resume with the saved continuation, same feed_range. + pager_post = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=ct_after_page_1) + post_pages, post_ids = _drain_pages(pager_post) + + combined_ids = page_1_ids + post_ids + unique = set(combined_ids) + duplicate_count = len(combined_ids) - len(unique) + # When the parent's backend continuation is dropped during the + # post-split explode (children won't accept the parent's bc), + # the children restart at offset 0 of their slice. The lower + # child can therefore re-emit up to PAGE_SIZE rows that page 1 + # already returned. The strict no-dup invariant only holds when + # page 1 happened to fully drain the parent slice; the + # bounded-replay invariant always holds and is what we assert + # here. The strict no-loss / no-out-of-range guarantee is still + # enforced by the ``unique == ground_truth`` check below. + assert duplicate_count <= PAGE_SIZE, ( + f"unexpected duplicate count across the split boundary: " + f"{len(combined_ids)} ids returned across page 1 + " + f"{len(post_pages)} post-split page(s), {len(unique)} distinct, " + f"{duplicate_count} duplicate(s) (max allowed: {PAGE_SIZE}).") + + oversized = [(i, len(p)) for i, p in enumerate(post_pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated post-split; sizes=" + f"{[len(p) for p in post_pages]}; oversized={oversized}.") + + # Re-derive ground truth against the post-split routing map. + partitions_after = _sorted_partition_ranges(container) + post_split_overlaps = [(mn, mx) for (mn, mx) in partitions_after + if min(p1_max, mx) > max(p0_min, mn)] + ground_truth = _ids_via_per_partition_scan(container, post_split_overlaps) + assert unique == ground_truth, ( + f"item-set mismatch after post-split resume: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + + # ------------------------------------------------------------------ # + # Legacy opaque token compatibility + # ------------------------------------------------------------------ # + def test_legacy_opaque_token_compat(self, caplog): + """Use an opaque continuation token (not base64 JSON, not v=1). + The query restarts from the beginning. + + Asserts: + (a) no exception is raised on the resume call, + (b) all batches restart from the beginning (the union of ids + returned equals the per-partition ground truth), + (c) every page respects ``max_item_count``, + (d) pagination runs to completion (final continuation is None). + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = _ids_via_per_partition_scan(container, [chosen[0], chosen[1]]) + + # Opaque continuation string that does not match the structured + # token format. + # cspell:ignore AOXB BAAAAAAAAAA EAAAAFAAAA + legacy_token = "+RID:~Yxs1AOXBSp4BAAAAAAAAAA==#RT:1#TRC:5#ISV:2#IEO:65567#FPC:AgEAAAAFAAAA" + + with caplog.at_level("WARNING", logger="azure.cosmos._cosmos_client_connection"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=legacy_token) + + # (a) no exception on iteration + pages, all_ids = _drain_pages(pager) + + assert any( + "not in the supported structured format" in record.getMessage() + for record in caplog.records + ), "Expected warning log when a non-structured continuation token is supplied" + assert all( + legacy_token not in record.getMessage() for record in caplog.records + ), "Warning log must not include raw continuation token text" + + # (c) page-size limit respected + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated under legacy-token resume; " + f"sizes={[len(p) for p in pages]}; oversized={oversized}") + + # (b) full restart from offset 0 -> coverage matches ground truth, + # no duplicates + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"legacy-token restart produced duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"legacy-token restart coverage mismatch: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + # (d) pagination ran to completion + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages on " + f"legacy-token restart; got {pager.continuation_token!r}") + + # ------------------------------------------------------------------ # + # Identity-fingerprint mismatch rejection (live half) + # ------------------------------------------------------------------ # + def test_token_identity_mismatch_rejected(self): + """Round-trip a token through ``query_items`` then replay it + against (a) a different query text, (b) a different parameter + value, and (c) a different ``feed_range``. Each replay must raise + ``ValueError`` from the call-site validation in ``__QueryFeed`` + with a message that names the failing field. + + The unit tests in ``test_feed_range_continuation_token.py`` cover + the hash-inequality contract; this test covers the live raise + path through the SDK's actual query pipeline. + """ + container = _get_container() + partitions = _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (_count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Step 1 — drain page 1 of a parameterized query and save the + # outbound continuation. The token's qh/frh fingerprints encode + # this query + this feed_range. + # Use bracket notation: ``value`` is a reserved word in Cosmos SQL. + original_query = "SELECT * FROM c WHERE c[\"value\"] >= @v" + original_params = [{"name": "@v", "value": 0}] + pager = container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + _ = list(next(pager)) + token = pager.continuation_token + assert token, ("expected a non-empty continuation after page 1; " + "the test cannot exercise resume validation otherwise") + + # (a) Different query TEXT — qh mismatch. + with pytest.raises(ValueError) as excinfo_q: + list(container.query_items( + query={"query": "SELECT * FROM c WHERE c[\"value\"] >= @v AND c.id != ''", + "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + msg_q = str(excinfo_q.value).lower() + assert "query" in msg_q, ( + "ValueError on query-text mismatch must name the failing " + f"field; got: {excinfo_q.value!r}") + + # (b) Different parameter VALUE — same query text, different qh. + with pytest.raises(ValueError) as excinfo_p: + list(container.query_items( + query={"query": original_query, + "parameters": [{"name": "@v", "value": 999999}]}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "query" in str(excinfo_p.value).lower(), ( + "ValueError on parameter-value mismatch must name the query " + f"field; got: {excinfo_p.value!r}") + + # (c) Different feed_range — frh mismatch. Use a single-partition + # sub-range of the original crossing range (still inside the same + # collection so cr matches; only frh differs). + single_p0 = _crossing_feed_range(chosen[0][0], chosen[0][1]) + with pytest.raises(ValueError) as excinfo_f: + list(container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=single_p0, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "feed_range" in str(excinfo_f.value).lower(), ( + "ValueError on feed_range mismatch must name the feed_range " + f"field; got: {excinfo_f.value!r}") diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py new file mode 100644 index 000000000000..83ffc44fe06c --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py @@ -0,0 +1,811 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""Async multi-partition feed_range tests.""" + +import asyncio +import unittest +import uuid +from typing import Iterable, List, Optional, Tuple + +import pytest +import pytest_asyncio + +import test_config +from azure.cosmos import _base +from azure.cosmos import http_constants +from azure.cosmos.aio import CosmosClient +from azure.cosmos._routing.feed_range_continuation import _decode_token +from azure.cosmos.partition_key import PartitionKey + +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID + +REPRO_CONTAINER_ID = "FeedRangeMultiPartitionAsync-" + str(uuid.uuid4()) +REPRO_PARTITION_KEY = "pk" +REPRO_THROUGHPUT = CONFIG.THROUGHPUT_FOR_5_PARTITIONS +REPRO_DOC_COUNT = 200 +PAGE_SIZE = 5 +MIN_DOCS_PER_PARTITION = 15 + + +def _client() -> CosmosClient: + return CosmosClient(HOST, KEY) + + +def _get_container(client: CosmosClient): + return client.get_database_client(DATABASE_ID).get_container_client(REPRO_CONTAINER_ID) + + +async def _sorted_partition_ranges(container) -> List[Tuple[str, str]]: + pairs: List[Tuple[str, str]] = [] + async for fr in container.read_feed_ranges(): + r = fr["Range"] + pairs.append((r["min"], r["max"])) + pairs.sort(key=lambda p: p[0]) + return pairs + + +async def _count_in_range(container, range_min: str, range_max: str) -> int: + fr = test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + items = [it async for it in container.query_items( + query="SELECT VALUE COUNT(1) FROM c", feed_range=fr)] + return items[0] if items else 0 + + +def _crossing_feed_range(range_min: str, range_max: str): + return test_config.create_feed_range_in_dict( + test_config.create_range(range_min=range_min, range_max=range_max, + is_min_inclusive=True, is_max_inclusive=False)) + + +async def _ids_via_per_partition_scan(container, partition_ranges: Iterable[Tuple[str, str]]): + ground_truth = set() + for (mn, mx) in partition_ranges: + fr = _crossing_feed_range(mn, mx) + async for item in container.query_items(query="SELECT c.id FROM c", feed_range=fr): + ground_truth.add(item["id"]) + return ground_truth + + +async def _values_via_per_partition_scan(container, partition_ranges: Iterable[Tuple[str, str]]): + values = [] + for (mn, mx) in partition_ranges: + fr = _crossing_feed_range(mn, mx) + async for value in container.query_items(query='SELECT VALUE c["value"] FROM c', feed_range=fr): + values.append(value) + return values + + +async def _drain_pages(pager) -> Tuple[List[List[dict]], List[str]]: + pages: List[List[dict]] = [] + all_ids: List[str] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + all_ids.extend(it["id"] for it in items) + return pages, all_ids + + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown_async(): + client = _client() + db = client.get_database_client(DATABASE_ID) + container = await db.create_container_if_not_exists( + id=REPRO_CONTAINER_ID, + partition_key=PartitionKey(path="/" + REPRO_PARTITION_KEY, kind="Hash"), + offer_throughput=REPRO_THROUGHPUT) + for i in range(REPRO_DOC_COUNT): + await container.upsert_item({ + REPRO_PARTITION_KEY: f"pk-{i:04d}", + "id": f"doc-{i:04d}", + "value": i, + }) + yield + try: + await db.delete_container(REPRO_CONTAINER_ID) + except Exception: # pylint: disable=broad-except + pass + await client.close() + + +@pytest.mark.cosmosQuery +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown_async") +class TestFeedRangeMultiPartitionAsync: + """Async end-to-end tests for feed_range queries that overlap multiple + physical partitions.""" + + # ------------------------------------------------------------------ # + # Single-partition control + # ------------------------------------------------------------------ # + async def test_single_partition_feed_range_async(self): + """Single-partition regression guard.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if not partitions: + pytest.skip("Container has no physical partitions") + + chosen_pp = None + for (mn, mx) in partitions: + if await _count_in_range(container, mn, mx) >= MIN_DOCS_PER_PARTITION: + chosen_pp = (mn, mx) + break + if chosen_pp is None: + pytest.skip("No single partition populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + single = _crossing_feed_range(chosen_pp[0], chosen_pp[1]) + ground_truth = await _ids_via_per_partition_scan(container, [chosen_pp]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=single, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + for idx, page in enumerate(pages): + if idx < len(pages) - 1: + assert len(page) == PAGE_SIZE, ( + f"page {idx} returned {len(page)} items, expected " + f"exactly {PAGE_SIZE}") + else: + assert len(page) <= PAGE_SIZE + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"single-partition path returned duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"single-partition coverage mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + + # ------------------------------------------------------------------ # + # Two-partition feed_range + # ------------------------------------------------------------------ # + async def test_two_partition_feed_range_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = await _ids_via_per_partition_scan( + container, [chosen[0], chosen[1]]) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages: {oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + async def test_two_partition_feed_range_count_aggregate_pagination_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + expected_count = len(await _ids_via_per_partition_scan(container, [chosen[0], chosen[1]])) + + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_count, ( + "merged COUNT result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_count}") + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + @pytest.mark.parametrize("merge_error_type", [TypeError, KeyError]) + async def test_two_partition_feed_range_merge_fallback_preserves_rows_async( + self, monkeypatch, caplog, merge_error_type + ): + """Force merge failures and verify fallback extends docs without loss.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with >= " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = await _ids_via_per_partition_scan(container, [chosen[0], chosen[1]]) + + merge_call_count = 0 + + def _raising_merge(*_args, **_kwargs): + nonlocal merge_call_count + merge_call_count += 1 + raise merge_error_type("injected-merge-failure") + + monkeypatch.setattr(_base, "_merge_query_results", _raising_merge) + + with caplog.at_level("WARNING", logger="azure.cosmos.aio._cosmos_client_connection_async"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + assert merge_call_count > 0 + assert any( + "Falling back to non-aggregate merge after aggregate merge failure" in record.getMessage() + for record in caplog.records + ), "Expected warning log for merge fallback path" + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated (max_item_count={PAGE_SIZE}); " + f"got pages with sizes {[len(p) for p in pages]}; " + f"oversized pages (index, size): {oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"fallback path produced duplicates: {len(all_ids)} items returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + assert unique == ground_truth, ( + f"fallback path dropped/added items: returned={len(unique)} ids, " + f"ground_truth={len(ground_truth)} ids, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + async def test_two_partition_feed_range_min_aggregate_pagination_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + expected_values = await _values_via_per_partition_scan(container, [chosen[0], chosen[1]]) + expected_min = min(expected_values) + + pager = container.query_items( + query='SELECT VALUE MIN(c["value"]) FROM c', + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_min, ( + "merged MIN result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_min}") + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + async def test_two_partition_feed_range_max_aggregate_pagination_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + expected_values = await _values_via_per_partition_scan(container, [chosen[0], chosen[1]]) + expected_max = max(expected_values) + + pager = container.query_items( + query='SELECT VALUE MAX(c["value"]) FROM c', + feed_range=crossing, + max_item_count=1, + ).by_page() + + pages: List[List[object]] = [] + merged_rows: List[object] = [] + async for page in pager: + items = [it async for it in page] + pages.append(items) + merged_rows.extend(items) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > 1] + assert not oversized, ( + "aggregate page-size limit violated (max_item_count=1); " + f"page sizes={[len(p) for p in pages]}, oversized={oversized}.") + assert len(merged_rows) == 1, ( + "aggregate merge leaked partial fragments or dropped final value; " + f"expected one merged row, got {len(merged_rows)} rows: {merged_rows}") + assert merged_rows[0] == expected_max, ( + "merged MAX result mismatch for two-partition crossing feed_range; " + f"returned={merged_rows[0]}, expected={expected_max}") + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining aggregate query; got " + f"{pager.continuation_token!r}") + finally: + await client.close() + + async def test_exception_during_post_preserves_resume_checkpoint_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + client_conn = container.client_connection + original_post = client_conn._CosmosClientConnection__Post + call_count = 0 + + async def _failing_post(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("injected-post-failure") + return await original_post(*args, **kwargs) + + client_conn._CosmosClientConnection__Post = _failing_post + try: + pager = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + feed_range=crossing, + max_item_count=1, + ).by_page() + with pytest.raises(RuntimeError, match="injected-post-failure"): + page_iter = await pager.__anext__() + _ = [it async for it in page_iter] + finally: + client_conn._CosmosClientConnection__Post = original_post + + token = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + assert token, "Expected continuation checkpoint to be stamped on POST failure" + decoded = _decode_token(token) + assert decoded is not None + assert decoded["c"][0]["min"] == chosen[1][0], ( + "Checkpoint should resume from the second sub-range after the first " + "sub-range completed successfully before failure." + ) + finally: + await client.close() + + # ------------------------------------------------------------------ # + # Three-way overlap + # ------------------------------------------------------------------ # + async def test_three_way_overlap_async(self): + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 3: + pytest.skip("Need a container with ≥ 3 physical partitions") + + chosen: Optional[List[Tuple[str, str]]] = None + for i in range(len(partitions) - 2): + triple = partitions[i:i + 3] + ok = True + for mn, mx in triple: + if await _count_in_range(container, mn, mx) < MIN_DOCS_PER_PARTITION: + ok = False + break + if ok: + chosen = triple + break + if chosen is None: + pytest.skip("No three adjacent partitions all populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + assert chosen is not None # narrows for type checkers; pytest.skip raises + crossing = _crossing_feed_range(chosen[0][0], chosen[2][1]) + ground_truth = await _ids_via_per_partition_scan(container, chosen) + + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + pages, all_ids = await _drain_pages(pager) + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated; sizes={[len(p) for p in pages]}; " + f"oversized={oversized}.") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"duplicates across pages: {len(all_ids)} returned, " + f"{len(unique)} distinct, " + f"{len(all_ids) - len(unique)} duplicate(s).") + + assert unique == ground_truth, ( + f"item-set mismatch: returned={len(unique)}, " + f"ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + # ------------------------------------------------------------------ # + # Post-split resume (slow) + # ------------------------------------------------------------------ # + @pytest.mark.cosmosSplit + async def test_post_split_resume_async(self): + client = _client() + try: + container = _get_container(client) + partitions_before = await _sorted_partition_ranges(container) + if len(partitions_before) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions_before) - 1): + p0, p1 = partitions_before[i], partitions_before[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Step 1 — drain page 1 only and save the outbound continuation. + pager_pre = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + page_1_iter = await pager_pre.__anext__() + page_1 = [it async for it in page_1_iter] + ct_after_page_1 = pager_pre.continuation_token + page_1_ids = [item["id"] for item in page_1] + assert ct_after_page_1, ( + "expected a non-empty continuation token after page 1") + + # Step 2 — trigger a real split. + target_throughput = max(REPRO_THROUGHPUT * 2, 60000) + try: + await test_config.TestConfig.trigger_split_async(container, target_throughput) + except unittest.SkipTest: + raise + await asyncio.sleep(10) + _ = [fr async for fr in container.read_feed_ranges(force_refresh=True)] + + # Step 3 — resume with the saved continuation, same feed_range. + pager_post = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=ct_after_page_1) + post_pages, post_ids = await _drain_pages(pager_post) + + combined_ids = page_1_ids + post_ids + unique = set(combined_ids) + duplicate_count = len(combined_ids) - len(unique) + # When the parent's backend continuation is dropped during + # the post-split explode (children won't accept the + # parent's bc), the children restart at offset 0 of their + # slice. The lower child can therefore re-emit up to + # PAGE_SIZE rows that page 1 already returned. The strict + # no-dup invariant only holds when page 1 happened to fully + # drain the parent slice; the bounded-replay invariant + # always holds and is what we assert here. The strict + # no-loss / no-out-of-range guarantee is still enforced by + # the ``unique == ground_truth`` check below. + assert duplicate_count <= PAGE_SIZE, ( + f"unexpected duplicate count across the split boundary: " + f"{len(combined_ids)} ids returned across page 1 + " + f"{len(post_pages)} post-split page(s), {len(unique)} distinct, " + f"{duplicate_count} duplicate(s) (max allowed: {PAGE_SIZE}).") + + oversized = [(i, len(p)) for i, p in enumerate(post_pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated post-split; sizes=" + f"{[len(p) for p in post_pages]}; oversized={oversized}.") + + partitions_after = await _sorted_partition_ranges(container) + post_split_overlaps = [(mn, mx) for (mn, mx) in partitions_after + if min(p1_max, mx) > max(p0_min, mn)] + ground_truth = await _ids_via_per_partition_scan(container, post_split_overlaps) + assert unique == ground_truth, ( + f"item-set mismatch after post-split resume: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}.") + finally: + await client.close() + + # ------------------------------------------------------------------ # + # Legacy opaque token compatibility + # ------------------------------------------------------------------ # + async def test_legacy_opaque_token_compat_async(self, caplog): + """Use an opaque continuation token and verify restart behavior.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + ground_truth = await _ids_via_per_partition_scan( + container, [chosen[0], chosen[1]]) + + # cspell:ignore AOXB BAAAAAAAAAA EAAAAFAAAA + legacy_token = "+RID:~Yxs1AOXBSp4BAAAAAAAAAA==#RT:1#TRC:5#ISV:2#IEO:65567#FPC:AgEAAAAFAAAA" + + with caplog.at_level("WARNING", logger="azure.cosmos.aio._cosmos_client_connection_async"): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=legacy_token) + pages, all_ids = await _drain_pages(pager) + + assert any( + "not in the supported structured format" in record.getMessage() + for record in caplog.records + ), "Expected warning log when a non-structured continuation token is supplied" + assert all( + legacy_token not in record.getMessage() for record in caplog.records + ), "Warning log must not include raw continuation token text" + + oversized = [(i, len(p)) for i, p in enumerate(pages) if len(p) > PAGE_SIZE] + assert not oversized, ( + f"page-size limit violated under legacy-token resume; " + f"sizes={[len(p) for p in pages]}; oversized={oversized}") + + unique = set(all_ids) + assert len(all_ids) == len(unique), ( + f"legacy-token restart produced duplicates: " + f"{len(all_ids) - len(unique)} duplicate id(s)") + assert unique == ground_truth, ( + f"legacy-token restart coverage mismatch: returned=" + f"{len(unique)}, ground_truth={len(ground_truth)}, " + f"missing={len(ground_truth - unique)}, " + f"unexpected={len(unique - ground_truth)}") + + assert pager.continuation_token in (None, "", b""), ( + f"expected empty continuation after draining all pages on " + f"legacy-token restart; got {pager.continuation_token!r}") + finally: + await client.close() + + # ------------------------------------------------------------------ # + # Identity-fingerprint mismatch rejection (live half) + # ------------------------------------------------------------------ # + async def test_token_identity_mismatch_rejected_async(self): + """Live identity-mismatch rejection test.""" + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with ≥ 2 physical partitions") + + chosen = None + for i in range(len(partitions) - 1): + p0, p1 = partitions[i], partitions[i + 1] + if (await _count_in_range(container, p0[0], p0[1]) >= MIN_DOCS_PER_PARTITION + and await _count_in_range(container, p1[0], p1[1]) >= MIN_DOCS_PER_PARTITION): + chosen = (p0, p1) + break + if chosen is None: + pytest.skip("No adjacent partition pair both populated with ≥ " + f"{MIN_DOCS_PER_PARTITION} docs") + (p0_min, _), (_, p1_max) = chosen + crossing = _crossing_feed_range(p0_min, p1_max) + + # Use bracket notation: ``value`` is a reserved word in Cosmos SQL. + original_query = "SELECT * FROM c WHERE c[\"value\"] >= @v" + original_params = [{"name": "@v", "value": 0}] + pager = container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + page_1_iter = await pager.__anext__() + _ = [it async for it in page_1_iter] + token = pager.continuation_token + assert token, ("expected a non-empty continuation after page 1; " + "test cannot exercise resume validation otherwise") + + async def _drain(p): + async for page in p: + _ = [it async for it in page] + + # (a) Different query TEXT. + with pytest.raises(ValueError) as excinfo_q: + await _drain(container.query_items( + query={"query": "SELECT * FROM c WHERE c[\"value\"] >= @v AND c.id != ''", + "parameters": original_params}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "query" in str(excinfo_q.value).lower(), ( + f"ValueError on query-text mismatch must name the failing " + f"field; got: {excinfo_q.value!r}") + + # (b) Different parameter VALUE. + with pytest.raises(ValueError) as excinfo_p: + await _drain(container.query_items( + query={"query": original_query, + "parameters": [{"name": "@v", "value": 999999}]}, + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "query" in str(excinfo_p.value).lower(), ( + f"ValueError on parameter-value mismatch must name the " + f"query field; got: {excinfo_p.value!r}") + + # (c) Different feed_range. + single_p0 = _crossing_feed_range(chosen[0][0], chosen[0][1]) + with pytest.raises(ValueError) as excinfo_f: + await _drain(container.query_items( + query={"query": original_query, "parameters": original_params}, + feed_range=single_p0, + max_item_count=PAGE_SIZE, + ).by_page(continuation_token=token)) + assert "feed_range" in str(excinfo_f.value).lower(), ( + f"ValueError on feed_range mismatch must name the " + f"feed_range field; got: {excinfo_f.value!r}") + finally: + await client.close() + + +if __name__ == "__main__": + unittest.main() + From 42292edcbea0638697b794228c7bf0876973cdbe Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 4 May 2026 14:50:24 -0500 Subject: [PATCH 03/18] pipeline fix --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 13 ++- .../azure/cosmos/_cosmos_client_connection.py | 1 + .../azure-cosmos/azure/cosmos/_version.py | 2 +- .../aio/_cosmos_client_connection_async.py | 2 +- .../test_feed_range_continuation_token.py | 6 +- .../tests/test_service_retry_policies.py | 67 ++++++++-------- .../test_service_retry_policies_async.py | 79 ++++++++----------- 8 files changed, 84 insertions(+), 87 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index c456fbeddadd..696b8fcae973 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.14.7 (Unreleased) #### Bugs Fixed +* Fixed bug where `SELECT VALUE AVG(...)` queries spanning multiple physical partitions returned mathematically incorrect merged values from client-side aggregation. These queries now raise `ValueError`. * Fixed a bug in `query_items(feed_range=...)` where pagination could return incorrect results after a partition split caused the supplied feed range to overlap multiple physical partitions. * Fixed bug where unavailable regional endpoints were dropped from the routing list instead of being kept as fallback options. See [PR 45200](https://github.com/Azure/azure-sdk-for-python/pull/45200) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index fab49b416c2c..27dcea10f30e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -207,7 +207,7 @@ def _merge_query_results( elif aggregate_fn == "MAX": results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index] elif aggregate_fn == "AVG": - raise ValueError( + raise _UnsupportedValueAvgMergeError( "VALUE AVG aggregate merge across partitions is not supported client-side." ) else: @@ -236,8 +236,7 @@ def _raise_query_merge_value_error(merge_error: ValueError) -> None: :type merge_error: ValueError :raises ValueError: Always re-raises, potentially with a clearer message. """ - merge_message = str(merge_error) - if "VALUE AVG aggregate merge across partitions is not supported client-side." in merge_message: + if isinstance(merge_error, _UnsupportedValueAvgMergeError): raise ValueError( "Unsupported query shape for range-scoped pagination: " "SELECT VALUE AVG(...) cannot be merged client-side when the query " @@ -246,6 +245,14 @@ def _raise_query_merge_value_error(merge_error: ValueError) -> None: raise merge_error +class _UnsupportedValueAvgMergeError(ValueError): + """Internal marker for unsupported client-side SELECT VALUE AVG(...) merge. + + This type is only used for SDK control flow (type-based handling), then + translated to a clearer public ValueError for callers. + """ + + def GetHeaders( # pylint: disable=too-many-statements,too-many-branches cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], default_headers: Mapping[str, Any], diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 225f32889b1c..715b2dab9934 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3429,6 +3429,7 @@ def _checkpoint_and_reraise(error: Exception) -> None: "Failed to write continuation while handling query POST failure: %s", continuation_write_error, ) + _capture_internal_headers(feedrange_response_headers) raise error # NOTE: Keep this feed_range pagination loop in sync with diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index af83f50bae55..d581f6fc3808 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.14.6" +VERSION = "4.14.7" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 8f75b106573f..a2cb325f0e2a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2103,7 +2103,6 @@ async def _Batch( options.get("partitionKey", None)) request_params.set_excluded_location_from_options(options) await base.set_session_token_header_async(self, headers, path, request_params, options) - request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[list[dict[str, Any]], CaseInsensitiveDict], result) @@ -3221,6 +3220,7 @@ def _checkpoint_and_reraise(error: Exception) -> None: "Failed to write continuation while handling query POST failure: %s", continuation_write_error, ) + _capture_internal_headers(feedrange_response_headers) raise error # NOTE: Keep this feed_range pagination loop in sync with diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py index 10aa7b4f2e01..12bf259e1b60 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -1015,13 +1015,15 @@ def test_value_merge_raises_if_aggregate_function_detection_is_missing(self, mon def test_value_avg_merge_raises_as_unsupported(self): query = "SELECT VALUE AVG(c.value) FROM c" - with pytest.raises(ValueError) as excinfo: + with pytest.raises(_base._UnsupportedValueAvgMergeError) as excinfo: _base._merge_query_results({"Documents": [7.0]}, {"Documents": [3.0]}, query) assert "VALUE AVG aggregate merge" in str(excinfo.value) def test_raise_query_merge_value_error_rewrites_value_avg_message(self): - original = ValueError("VALUE AVG aggregate merge across partitions is not supported client-side.") + original = _base._UnsupportedValueAvgMergeError( + "VALUE AVG aggregate merge across partitions is not supported client-side." + ) with pytest.raises(ValueError) as excinfo: _base._raise_query_merge_value_error(original) diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py index 0a4e19049886..f3a208a902dd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py @@ -38,6 +38,29 @@ def setUpClass(cls): cls.created_database = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.created_container = cls.created_database.get_container_client(cls.TEST_CONTAINER_ID) + def _setup_read_regions(self, location_cache, regions): + """Set read-region cache state consistently so retry expectations are deterministic.""" + location_cache.account_read_locations = regions + location_cache.account_read_regional_routing_contexts_by_location = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.available_read_regional_endpoints_by_locations = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + location_cache.effective_preferred_locations = regions + + def _setup_write_regions(self, location_cache, regions): + """Set write-region cache state consistently so retry expectations are deterministic.""" + location_cache.account_write_locations = regions + location_cache.account_write_regional_routing_contexts_by_location = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.available_write_regional_endpoints_by_locations = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + def test_service_request_retry_policy(self): mock_client = CosmosClient(self.host, self.masterKey) db = mock_client.get_database_client(self.TEST_DATABASE_ID) @@ -47,14 +70,9 @@ def test_service_request_retry_policy(self): # Save the original function self.original_execute_function = _retry_utility.ExecuteFunction - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Change the location cache to have 3 preferred read regions. original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) try: @@ -72,15 +90,7 @@ def test_service_request_retry_policy(self): # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt # equal to the available read region locations - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location - original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) try: @@ -95,8 +105,7 @@ def test_service_request_retry_policy(self): _retry_utility.ExecuteFunction = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1]) expected_counter = len(original_location_cache.read_regional_routing_contexts) try: # Reset the function to reset the counter @@ -110,10 +119,7 @@ def test_service_request_retry_policy(self): _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) expected_counter = len(original_location_cache.write_regional_routing_contexts) try: # Reset the function to reset the counter @@ -136,14 +142,9 @@ def test_service_response_retry_policy(self): # Save the original function self.original_execute_function = _retry_utility.ExecuteFunction - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Change the location cache to have 3 preferred read regions. original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) try: # Mock the function to return the ServiceResponseException we retry mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(Exception, self.original_execute_function) @@ -156,8 +157,7 @@ def test_service_response_retry_policy(self): _retry_utility.ExecuteFunction = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1]) try: # Reset the function to reset the counter mf = self.MockExecuteServiceResponseException(Exception) @@ -170,10 +170,7 @@ def test_service_response_retry_policy(self): _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = {self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) try: # Reset the function to reset the counter mf = self.MockExecuteServiceResponseException(Exception) diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py index eb3082fdc976..c210d78898dc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py @@ -10,7 +10,7 @@ from azure.core.exceptions import ServiceRequestError, ServiceResponseError import test_config -from azure.cosmos import DatabaseAccount, _location_cache +from azure.cosmos import DatabaseAccount from azure.cosmos._location_cache import RegionalRoutingContext from azure.cosmos._request_object import RequestObject from azure.cosmos.aio import CosmosClient, _retry_utility_async, _global_endpoint_manager_async @@ -47,6 +47,29 @@ async def asyncTearDown(self): self.connectionPolicy.ConnectionRetryConfiguration = None await self.client.close() + def _setup_read_regions(self, location_cache, regions): + """Set read-region cache state consistently so retry expectations are deterministic.""" + location_cache.account_read_locations = regions + location_cache.account_read_regional_routing_contexts_by_location = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.available_read_regional_endpoints_by_locations = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + location_cache.effective_preferred_locations = regions + + def _setup_write_regions(self, location_cache, regions): + """Set write-region cache state consistently so retry expectations are deterministic.""" + location_cache.account_write_locations = regions + location_cache.account_write_regional_routing_contexts_by_location = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.available_write_regional_endpoints_by_locations = { + region: self.REGIONAL_ENDPOINT for region in regions + } + location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) + async def test_service_request_retry_policy_async(self): # ServiceRequestErrors will always retry, and will retry once per preferred region async with CosmosClient(self.host, self.masterKey) as mock_client: @@ -57,15 +80,9 @@ async def test_service_request_retry_policy_async(self): # Save the original function self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Change the location cache to have 3 preferred read regions. original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) try: # Mock the function to return the ServiceRequestException we retry @@ -82,15 +99,7 @@ async def test_service_request_retry_policy_async(self): # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt # equal to the available read region locations - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location - original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) try: @@ -106,8 +115,7 @@ async def test_service_request_retry_policy_async(self): _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1]) expected_counter = len(original_location_cache.read_regional_routing_contexts) try: # Reset the function to reset the counter @@ -121,11 +129,7 @@ async def test_service_request_retry_policy_async(self): _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) expected_counter = len(original_location_cache.write_regional_routing_contexts) try: # Reset the function to reset the counter @@ -149,15 +153,9 @@ async def test_service_response_retry_policy_async(self): # Save the original function self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - # Change the location cache to have 3 preferred read regions and 3 available read endpoints by location + # Change the location cache to have 3 preferred read regions. original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) try: # Mock the function to return the ClientConnectionError we retry mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, @@ -171,8 +169,7 @@ async def test_service_response_retry_policy_async(self): _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we change the location cache to have only 1 preferred read region - original_location_cache.account_read_locations = [self.REGION1] - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1]) try: # Reset the function to reset the counter mf = self.MockExecuteServiceResponseException(AttributeError, None) @@ -185,11 +182,7 @@ async def test_service_response_retry_policy_async(self): _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we try it out with a write request - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) try: # Reset the function to reset the counter mf = self.MockExecuteServiceResponseException(AttributeError, None) @@ -205,11 +198,7 @@ async def test_service_response_retry_policy_async(self): # If we do a write request with a ClientConnectionError, # we will do cross-region retries like with read requests - original_location_cache.account_write_locations = [self.REGION1, self.REGION2] - original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) try: # Reset the function to reset the counter mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) From d77b332e8aabff10991cd98b31783c4b911de114 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 4 May 2026 16:53:07 -0500 Subject: [PATCH 04/18] fixing test --- .../test_service_retry_policies_async.py | 265 +++++++++--------- 1 file changed, 134 insertions(+), 131 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py index c210d78898dc..774a3159a7ac 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py @@ -3,6 +3,7 @@ import unittest import uuid +from unittest.mock import AsyncMock, patch import pytest from aiohttp.client_exceptions import (ClientConnectionError, ClientConnectionResetError, @@ -75,72 +76,73 @@ async def test_service_request_retry_policy_async(self): async with CosmosClient(self.host, self.masterKey) as mock_client: db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) + global_endpoint_manager = mock_client.client_connection._global_endpoint_manager created_item = await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) # Save the original function self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - - # Change the location cache to have 3 preferred read regions. - original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we test with a query operation, iterating through items sends request without request object - # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt - # equal to the available read region locations - - self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - expected_counter = len(original_location_cache.read_regional_routing_contexts) - - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf - items = [item async for item in container.query_items(query="SELECT * FROM c", - partition_key=created_item['pk'])] - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we change the location cache to have only 1 preferred read region - self._setup_read_regions(original_location_cache, [self.REGION1]) - expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we try it out with a write request - self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - expected_counter = len(original_location_cache.write_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + with patch.object(global_endpoint_manager, "refresh_endpoint_list", new=AsyncMock(return_value=None)): + # Change the location cache to have 3 preferred read regions. + original_location_cache = global_endpoint_manager.location_cache + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + expected_counter = len(original_location_cache.read_regional_routing_contexts) + try: + # Mock the function to return the ServiceRequestException we retry + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) + _retry_utility_async.ExecuteFunctionAsync = mf + await container.read_item(created_item['id'], created_item['pk']) + pytest.fail("Exception was not raised.") + except ServiceRequestError: + assert mf.counter == expected_counter + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # Now we test with a query operation, iterating through items sends request without request object + # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt + # equal to the available read region locations + + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + expected_counter = len(original_location_cache.read_regional_routing_contexts) + + try: + # Mock the function to return the ServiceRequestException we retry + mf = self.MockExecuteServiceRequestException() + _retry_utility_async.ExecuteFunctionAsync = mf + items = [item async for item in container.query_items(query="SELECT * FROM c", + partition_key=created_item['pk'])] + pytest.fail("Exception was not raised.") + except ServiceRequestError: + assert mf.counter == expected_counter + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # Now we change the location cache to have only 1 preferred read region + self._setup_read_regions(original_location_cache, [self.REGION1]) + expected_counter = len(original_location_cache.read_regional_routing_contexts) + try: + # Reset the function to reset the counter + mf = self.MockExecuteServiceRequestException() + _retry_utility_async.ExecuteFunctionAsync = mf + await container.read_item(created_item['id'], created_item['pk']) + pytest.fail("Exception was not raised.") + except ServiceRequestError: + assert mf.counter == expected_counter + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # Now we try it out with a write request + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + expected_counter = len(original_location_cache.write_regional_routing_contexts) + try: + # Reset the function to reset the counter + mf = self.MockExecuteServiceRequestException() + _retry_utility_async.ExecuteFunctionAsync = mf + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) + pytest.fail("Exception was not raised.") + except ServiceRequestError: + assert mf.counter == expected_counter + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function async def test_service_response_retry_policy_async(self): # For ServiceResponseErrors, we only do cross region retries on read requests or on ClientConnectionErrors @@ -148,79 +150,80 @@ async def test_service_response_retry_policy_async(self): async with CosmosClient(self.host, self.masterKey) as mock_client: db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) + global_endpoint_manager = mock_client.client_connection._global_endpoint_manager created_item = await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) # Save the original function self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - - # Change the location cache to have 3 preferred read regions. - original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - try: - # Mock the function to return the ClientConnectionError we retry - mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, - None, self.original_execute_function) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 3 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we change the location cache to have only 1 preferred read region - self._setup_read_regions(original_location_cache, [self.REGION1]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we try it out with a write request - self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf - # Even though we have 2 preferred write endpoints, - # we will only run the exception once due to no retries on write requests - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # If we do a write request with a ClientConnectionError, - # we will do cross-region retries like with read requests - self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we try it out with a write request with retry write enabled - which should retry once - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + with patch.object(global_endpoint_manager, "refresh_endpoint_list", new=AsyncMock(return_value=None)): + # Change the location cache to have 3 preferred read regions. + original_location_cache = global_endpoint_manager.location_cache + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + try: + # Mock the function to return the ClientConnectionError we retry + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, + None, self.original_execute_function) + _retry_utility_async.ExecuteFunctionAsync = mf + await container.read_item(created_item['id'], created_item['pk']) + pytest.fail("Exception was not raised.") + except ServiceResponseError: + assert mf.counter == 3 + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # Now we change the location cache to have only 1 preferred read region + self._setup_read_regions(original_location_cache, [self.REGION1]) + try: + # Reset the function to reset the counter + mf = self.MockExecuteServiceResponseException(AttributeError, None) + _retry_utility_async.ExecuteFunctionAsync = mf + await container.read_item(created_item['id'], created_item['pk']) + pytest.fail("Exception was not raised.") + except ServiceResponseError: + assert mf.counter == 1 + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # Now we try it out with a write request + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + try: + # Reset the function to reset the counter + mf = self.MockExecuteServiceResponseException(AttributeError, None) + _retry_utility_async.ExecuteFunctionAsync = mf + # Even though we have 2 preferred write endpoints, + # we will only run the exception once due to no retries on write requests + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) + pytest.fail("Exception was not raised.") + except ServiceResponseError: + assert mf.counter == 1 + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # If we do a write request with a ClientConnectionError, + # we will do cross-region retries like with read requests + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + try: + # Reset the function to reset the counter + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + _retry_utility_async.ExecuteFunctionAsync = mf + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) + pytest.fail("Exception was not raised.") + except ServiceResponseError: + assert mf.counter == 2 + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + # Now we try it out with a write request with retry write enabled - which should retry once + try: + # Reset the function to reset the counter + mf = self.MockExecuteServiceResponseException(AttributeError, None) + _retry_utility_async.ExecuteFunctionAsync = mf + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) + pytest.fail("Exception was not raised.") + except ServiceResponseError: + assert mf.counter == 2 + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function async def test_service_request_connection_retry_policy_async(self): # Mock the client retry policy to see the same-region retries that happen there From 1c57e672c1716eac353c0c30f4805971ceb00fdf Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 4 May 2026 22:32:42 -0500 Subject: [PATCH 05/18] fixing tests --- .../tests/test_computed_properties_async.py | 14 +- sdk/cosmos/azure-cosmos/tests/test_query.py | 22 +- .../azure-cosmos/tests/test_query_async.py | 22 +- .../tests/test_query_cross_partition.py | 22 +- .../tests/test_service_retry_policies.py | 174 +++----- .../test_service_retry_policies_async.py | 375 +++++++----------- 6 files changed, 255 insertions(+), 374 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py b/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py index 3b4a35a97318..7eaafb7a9e37 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio import unittest import uuid import pytest @@ -181,10 +182,17 @@ async def test_replace_with_new_computed_properties_async(self): container = await replaced_collection.read() assert new_computed_properties == container["computedProperties"] + # Give the backend a brief chance to complete computed-property indexing after replace. + queried_items = [] + for _ in range(10): + queried_items = [q async for q in + replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', + partition_key="test")] + if len(queried_items) == 3: + break + await asyncio.sleep(1) + # Test 1: Test first computed property - queried_items = [q async for q in - replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', - partition_key="test")] self.assertEqual(len(queried_items), 3) # Test 1 Negative: Test if using non-existent computed property name returns nothing diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index ae38147fb092..54abca5680cd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -41,8 +41,6 @@ def setUpClass(cls): use_multiple_write_locations = True cls.client = cosmos_client.CosmosClient(cls.host, cls.credential, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) - if cls.host == "https://localhost:8081/": - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" @contextmanager def _new_client_with_structured_full_pk_env(self, value: str): @@ -121,12 +119,20 @@ def test_populate_index_metrics(self): self.assertTrue(INDEX_HEADER_NAME in created_collection.client_connection.last_response_headers) index_metrics = created_collection.client_connection.last_response_headers[INDEX_HEADER_NAME] self.assertIsNotNone(index_metrics) - expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 'IndexSpec': '/pk/?', - 'FilterPreciseSet': True, 'IndexPreciseSet': True, - 'IndexImpactScore': 'High'}], - 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], - 'PotentialCompositeIndexes': []} - self.assertDictEqual(expected_index_metrics, index_metrics) + self.assertIn('UtilizedSingleIndexes', index_metrics) + self.assertIn('PotentialSingleIndexes', index_metrics) + self.assertIn('UtilizedCompositeIndexes', index_metrics) + self.assertIn('PotentialCompositeIndexes', index_metrics) + + # Backend index diagnostics can vary by region/build; validate a stable shape and key signal. + candidate_indexes = list(index_metrics.get('UtilizedSingleIndexes', [])) + candidate_indexes.extend(index_metrics.get('PotentialSingleIndexes', [])) + self.assertTrue(any( + idx.get('FilterExpression') == '' + and idx.get('IndexImpactScore') == 'High' + and idx.get('IndexSpec') in ('/pk/?', '/_epk/?') + for idx in candidate_indexes + )) self.created_db.delete_container(created_collection.id) # TODO: Need to validate the query request count logic diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index fd3d65f1bfbd..aad19680f546 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -52,8 +52,6 @@ async def asyncSetUp(self): self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=self.use_multiple_write_locations) await self.client.__aenter__() self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) - if self.host == "https://localhost:8081/": - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" async def asyncTearDown(self): await self.client.close() @@ -141,12 +139,20 @@ async def test_populate_index_metrics_async(self): assert index_header_name in created_collection.client_connection.last_response_headers index_metrics = created_collection.client_connection.last_response_headers[index_header_name] assert index_metrics != {} - expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 'IndexSpec': '/pk/?', - 'FilterPreciseSet': True, 'IndexPreciseSet': True, - 'IndexImpactScore': 'High'}], - 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], - 'PotentialCompositeIndexes': []} - assert expected_index_metrics == index_metrics + assert 'UtilizedSingleIndexes' in index_metrics + assert 'PotentialSingleIndexes' in index_metrics + assert 'UtilizedCompositeIndexes' in index_metrics + assert 'PotentialCompositeIndexes' in index_metrics + + # Backend index diagnostics can vary by region/build; validate stable signal instead of exact payload. + candidate_indexes = list(index_metrics.get('UtilizedSingleIndexes', [])) + candidate_indexes.extend(index_metrics.get('PotentialSingleIndexes', [])) + assert any( + idx.get('FilterExpression') == '' + and idx.get('IndexImpactScore') == 'High' + and idx.get('IndexSpec') in ('/pk/?', '/_epk/?') + for idx in candidate_indexes + ) await self.created_db.delete_container(created_collection.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py index b1b4f050ab63..6daaf2304ad6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py @@ -44,8 +44,6 @@ def setUpClass(cls): use_multiple_write_locations = True cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) - if cls.host == "https://localhost:8081/": - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" def setUp(self): self.created_container = self.created_db.create_container( @@ -220,12 +218,20 @@ def test_populate_index_metrics(self): self.assertTrue(INDEX_HEADER_NAME in self.created_container.client_connection.last_response_headers) index_metrics = self.created_container.client_connection.last_response_headers[INDEX_HEADER_NAME] self.assertIsNotNone(index_metrics) - expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 'IndexSpec': '/pk/?', - 'FilterPreciseSet': True, 'IndexPreciseSet': True, - 'IndexImpactScore': 'High'}], - 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], - 'PotentialCompositeIndexes': []} - self.assertDictEqual(expected_index_metrics, index_metrics) + self.assertIn('UtilizedSingleIndexes', index_metrics) + self.assertIn('PotentialSingleIndexes', index_metrics) + self.assertIn('UtilizedCompositeIndexes', index_metrics) + self.assertIn('PotentialCompositeIndexes', index_metrics) + + # Backend index diagnostics can vary by region/build; validate stable signal instead of exact payload. + candidate_indexes = list(index_metrics.get('UtilizedSingleIndexes', [])) + candidate_indexes.extend(index_metrics.get('PotentialSingleIndexes', [])) + self.assertTrue(any( + idx.get('FilterExpression') == '' + and idx.get('IndexImpactScore') == 'High' + and idx.get('IndexSpec') in ('/pk/?', '/_epk/?') + for idx in candidate_indexes + )) def test_get_query_plan_through_gateway(self): self._validate_query_plan(query="Select top 10 value count(c.id) from c", diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py index f3a208a902dd..0a3555a908cf 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. import unittest import uuid +from unittest.mock import patch import pytest from azure.core.exceptions import ServiceRequestError, ServiceResponseError @@ -39,26 +40,18 @@ def setUpClass(cls): cls.created_container = cls.created_database.get_container_client(cls.TEST_CONTAINER_ID) def _setup_read_regions(self, location_cache, regions): - """Set read-region cache state consistently so retry expectations are deterministic.""" + """Set all read region attributes consistently so update_location_cache() recalculates correctly.""" location_cache.account_read_locations = regions location_cache.account_read_regional_routing_contexts_by_location = { - region: self.REGIONAL_ENDPOINT for region in regions - } - location_cache.available_read_regional_endpoints_by_locations = { - region: self.REGIONAL_ENDPOINT for region in regions - } + r: self.REGIONAL_ENDPOINT for r in regions} location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) location_cache.effective_preferred_locations = regions def _setup_write_regions(self, location_cache, regions): - """Set write-region cache state consistently so retry expectations are deterministic.""" + """Set all write region attributes consistently so update_location_cache() recalculates correctly.""" location_cache.account_write_locations = regions location_cache.account_write_regional_routing_contexts_by_location = { - region: self.REGIONAL_ENDPOINT for region in regions - } - location_cache.available_write_regional_endpoints_by_locations = { - region: self.REGIONAL_ENDPOINT for region in regions - } + r: self.REGIONAL_ENDPOINT for r in regions} location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) def test_service_request_retry_policy(self): @@ -67,71 +60,47 @@ def test_service_request_retry_policy(self): container = db.get_container_client(self.TEST_CONTAINER_ID) created_item = container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility.ExecuteFunction - # Change the location cache to have 3 preferred read regions. + # Change the location cache to have 3 preferred read regions original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function - # Now we test with a query operation, iterating through items sends request without request object - # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt - # equal to the available read region locations + # Test read with IgnoreQuery mock (allows query/pkranges requests through) + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(_retry_utility.ExecuteFunction) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + container.read_item(created_item['id'], created_item['pk']) + assert mf.counter == expected_counter + # Now we test with a query operation self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() - _retry_utility.ExecuteFunction = mf - items = list(container.query_items(query="SELECT * FROM c", partition_key=created_item['pk'])) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + list(container.query_items(query="SELECT * FROM c", partition_key=created_item['pk'])) assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we change the location cache to have only 1 preferred read region self._setup_read_regions(original_location_cache, [self.REGION1]) expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + container.read_item(created_item['id'], created_item['pk']) assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) expected_counter = len(original_location_cache.write_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility.ExecuteFunction = mf - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - # Should retry twice in each region + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceRequestError): + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == expected_counter - finally: - _retry_utility.ExecuteFunction = self.original_execute_function def test_service_response_retry_policy(self): mock_client = CosmosClient(self.host, self.masterKey) @@ -139,62 +108,41 @@ def test_service_response_retry_policy(self): container = db.get_container_client(self.TEST_CONTAINER_ID) created_item = container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility.ExecuteFunction - # Change the location cache to have 3 preferred read regions. + # Change the location cache to have 3 preferred read regions original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - try: - # Mock the function to return the ServiceResponseException we retry - mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(Exception, self.original_execute_function) - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(Exception, _retry_utility.ExecuteFunction) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + container.read_item(created_item['id'], created_item['pk']) assert mf.counter == 3 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we change the location cache to have only 1 preferred read region self._setup_read_regions(original_location_cache, [self.REGION1]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(Exception) - _retry_utility.ExecuteFunction = mf - container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(Exception) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + container.read_item(created_item['id'], created_item['pk']) assert mf.counter == 1 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(Exception) - _retry_utility.ExecuteFunction = mf - # Even though we have 2 preferred write endpoints, - # we will only run the exception once due to no retries on write requests - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(Exception) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + # Even though we have 2 preferred write endpoints, + # we will only run the exception once due to no retries on write requests + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function # Now we try it out with a write request with retry write enabled - which should retry once - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(Exception) - _retry_utility.ExecuteFunction = mf - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(Exception) + with patch.object(_retry_utility, 'ExecuteFunction', mf): + with pytest.raises(ServiceResponseError): + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) assert mf.counter == 2 - finally: - _retry_utility.ExecuteFunction = self.original_execute_function def test_service_request_connection_retry_policy(self): # Mock the client retry policy to see the same-region retries that happen there @@ -248,38 +196,32 @@ def test_global_endpoint_manager_retry(self): # - GetDatabaseAccountStub allows us to receive any number of endpoints for that call independent of account used exception = ServiceRequestError("mock exception") exception.exc_type = Exception - self.original_get_database_account_stub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub connection_retry_policy = test_config.MockConnectionRetryPolicy(resource_type="docs", error=exception) - mock_client = CosmosClient(self.host, self.masterKey, connection_retry_policy=connection_retry_policy, - preferred_locations=[self.REGION1, self.REGION2]) - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub + with patch.object(_global_endpoint_manager._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + mock_client = CosmosClient(self.host, self.masterKey, connection_retry_policy=connection_retry_policy, + preferred_locations=[self.REGION1, self.REGION2]) + db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) - try: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + with patch.object(_global_endpoint_manager._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + with pytest.raises(ServiceRequestError): + container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert connection_retry_policy.counter == 3 # 4 total in region retries assert len(connection_retry_policy.request_endpoints) == 4 - finally: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub # Now we try with a read request - reset the policy to reset the counter connection_retry_policy.request_endpoints = [] - try: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - container.read_item("some_id", "some_pk") - pytest.fail("Exception was not raised.") - except ServiceRequestError: + with patch.object(_global_endpoint_manager._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + with pytest.raises(ServiceRequestError): + container.read_item("some_id", "some_pk") assert connection_retry_policy.counter == 3 # 4 total requests in each main region (preferred read region 1 -> preferred read region 2) assert len(connection_retry_policy.request_endpoints) == 8 - finally: - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub class MockExecuteServiceRequestException(object): def __init__(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py index 774a3159a7ac..92419c71fb34 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py @@ -3,7 +3,7 @@ import unittest import uuid -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest from aiohttp.client_exceptions import (ClientConnectionError, ClientConnectionResetError, @@ -48,27 +48,31 @@ async def asyncTearDown(self): self.connectionPolicy.ConnectionRetryConfiguration = None await self.client.close() + async def _cancel_background_refresh_task(self, client): + """Cancel the GEM's background health check task to prevent it from + overwriting test cache mutations via update_location_cache().""" + gem = client.client_connection._global_endpoint_manager + if gem.refresh_task and not gem.refresh_task.done(): + gem.refresh_task.cancel() + try: + await gem.refresh_task + except BaseException: + pass + gem.refresh_task = None + def _setup_read_regions(self, location_cache, regions): - """Set read-region cache state consistently so retry expectations are deterministic.""" + """Set all read region attributes consistently so update_location_cache() recalculates correctly.""" location_cache.account_read_locations = regions location_cache.account_read_regional_routing_contexts_by_location = { - region: self.REGIONAL_ENDPOINT for region in regions - } - location_cache.available_read_regional_endpoints_by_locations = { - region: self.REGIONAL_ENDPOINT for region in regions - } + r: self.REGIONAL_ENDPOINT for r in regions} location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) location_cache.effective_preferred_locations = regions def _setup_write_regions(self, location_cache, regions): - """Set write-region cache state consistently so retry expectations are deterministic.""" + """Set all write region attributes consistently so update_location_cache() recalculates correctly.""" location_cache.account_write_locations = regions location_cache.account_write_regional_routing_contexts_by_location = { - region: self.REGIONAL_ENDPOINT for region in regions - } - location_cache.available_write_regional_endpoints_by_locations = { - region: self.REGIONAL_ENDPOINT for region in regions - } + r: self.REGIONAL_ENDPOINT for r in regions} location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT] * len(regions) async def test_service_request_retry_policy_async(self): @@ -76,73 +80,51 @@ async def test_service_request_retry_policy_async(self): async with CosmosClient(self.host, self.masterKey) as mock_client: db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) - global_endpoint_manager = mock_client.client_connection._global_endpoint_manager created_item = await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - with patch.object(global_endpoint_manager, "refresh_endpoint_list", new=AsyncMock(return_value=None)): - # Change the location cache to have 3 preferred read regions. - original_location_cache = global_endpoint_manager.location_cache - self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) - _retry_utility_async.ExecuteFunctionAsync = mf + + # Cancel background health check to prevent it from overwriting cache mutations + await self._cancel_background_refresh_task(mock_client) + + # Change the location cache to have 3 preferred read regions + original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + expected_counter = len(original_location_cache.read_regional_routing_contexts) + + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(_retry_utility_async.ExecuteFunctionAsync) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we test with a query operation, iterating through items sends request without request object - # retry policy should eventually raise an exception as it should stop retrying with a max retry attempt - # equal to the available read region locations - - self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - expected_counter = len(original_location_cache.read_regional_routing_contexts) - - try: - # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == expected_counter + + # Now we test with a query operation + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + expected_counter = len(original_location_cache.read_regional_routing_contexts) + + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): items = [item async for item in container.query_items(query="SELECT * FROM c", partition_key=created_item['pk'])] - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we change the location cache to have only 1 preferred read region - self._setup_read_regions(original_location_cache, [self.REGION1]) - expected_counter = len(original_location_cache.read_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == expected_counter + + # Now we change the location cache to have only 1 preferred read region + self._setup_read_regions(original_location_cache, [self.REGION1]) + expected_counter = len(original_location_cache.read_regional_routing_contexts) + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we try it out with a write request - self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - expected_counter = len(original_location_cache.write_regional_routing_contexts) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceRequestException() - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == expected_counter + + # Now we try it out with a write request + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + expected_counter = len(original_location_cache.write_regional_routing_contexts) + mf = self.MockExecuteServiceRequestException() + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceRequestError): await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: - assert mf.counter == expected_counter - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + assert mf.counter == expected_counter async def test_service_response_retry_policy_async(self): # For ServiceResponseErrors, we only do cross region retries on read requests or on ClientConnectionErrors @@ -150,80 +132,56 @@ async def test_service_response_retry_policy_async(self): async with CosmosClient(self.host, self.masterKey) as mock_client: db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) - global_endpoint_manager = mock_client.client_connection._global_endpoint_manager created_item = await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - # Save the original function - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - with patch.object(global_endpoint_manager, "refresh_endpoint_list", new=AsyncMock(return_value=None)): - # Change the location cache to have 3 preferred read regions. - original_location_cache = global_endpoint_manager.location_cache - self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) - try: - # Mock the function to return the ClientConnectionError we retry - mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, - None, self.original_execute_function) - _retry_utility_async.ExecuteFunctionAsync = mf + + # Cancel background health check to prevent it from overwriting cache mutations + await self._cancel_background_refresh_task(mock_client) + + # Change the location cache to have 3 preferred read regions + original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, + None, _retry_utility_async.ExecuteFunctionAsync) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 3 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we change the location cache to have only 1 preferred read region - self._setup_read_regions(original_location_cache, [self.REGION1]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == 3 + + # Now we change the location cache to have only 1 preferred read region + self._setup_read_regions(original_location_cache, [self.REGION1]) + mf = self.MockExecuteServiceResponseException(AttributeError, None) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): await container.read_item(created_item['id'], created_item['pk']) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we try it out with a write request - self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == 1 + + # Now we try it out with a write request + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + mf = self.MockExecuteServiceResponseException(AttributeError, None) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): # Even though we have 2 preferred write endpoints, # we will only run the exception once due to no retries on write requests await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # If we do a write request with a ClientConnectionError, - # we will do cross-region retries like with read requests - self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == 1 + + # If we do a write request with a ClientConnectionError, + # we will do cross-region retries like with read requests + self._setup_write_regions(original_location_cache, [self.REGION1, self.REGION2]) + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - # Now we try it out with a write request with retry write enabled - which should retry once - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(AttributeError, None) - _retry_utility_async.ExecuteFunctionAsync = mf + assert mf.counter == 2 + + # Now we try it out with a write request with retry write enabled - which should retry once + mf = self.MockExecuteServiceResponseException(AttributeError, None) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}, retry_write=2) - pytest.fail("Exception was not raised.") - except ServiceResponseError: - assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + assert mf.counter == 2 async def test_service_request_connection_retry_policy_async(self): # Mock the client retry policy to see the same-region retries that happen there @@ -277,125 +235,91 @@ async def test_service_response_connection_retry_policy_async(self): async def test_service_response_errors_async(self): # Test for errors that are subclasses of ClientConnectionError for write requests - # Save the original ExecuteAsyncFunction function - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync async with CosmosClient(self.host, self.masterKey) as mock_client: db = mock_client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_ID) await container.read() + # Cancel background health check to prevent it from overwriting cache mutations + await self._cancel_background_refresh_task(mock_client) + original_location_cache = mock_client.client_connection._global_endpoint_manager.location_cache - original_location_cache.account_read_locations = [self.REGION1, self.REGION2, self.REGION3] - original_location_cache.available_read_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT, - self.REGION3: self.REGIONAL_ENDPOINT} - original_location_cache.read_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT, - self.REGIONAL_ENDPOINT] + self._setup_read_regions(original_location_cache, [self.REGION1, self.REGION2, self.REGION3]) + # For writes, set only the derived state directly since test_service_response_errors + # relies on mark_endpoint_unavailable -> update_location_cache() reducing write regions original_location_cache.account_write_locations = [self.REGION1, self.REGION2] original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] - original_location_cache.available_write_regional_endpoints_by_locations = { - self.REGION1: self.REGIONAL_ENDPOINT, - self.REGION2: self.REGIONAL_ENDPOINT} - try: - # Start with a normal ServiceResponseException with no special casing - mf = self.MockExecuteServiceResponseException(AttributeError, AttributeError()) - _retry_utility_async.ExecuteFunctionAsync = mf - # Even though we have 2 preferred write endpoints, - # we will only run the exception once due to no retries on write requests - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + + # Start with a normal ServiceResponseException with no special casing + mf = self.MockExecuteServiceResponseException(AttributeError, AttributeError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + # Even though we have 2 preferred write endpoints, + # we will only run the exception once due to no retries on write requests + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Now we test the base ClientConnectionError to see in-region retry - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 assert len(original_location_cache.location_unavailability_info_by_endpoint) == 1 host_unavailable = original_location_cache.location_unavailability_info_by_endpoint.get(self.host) assert host_unavailable is not None assert len(host_unavailable.get('operationType')) == 1 assert 'Write' in host_unavailable.get('operationType') - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function await container.read() # We send another request to the same error - since we marked one of the two in-region endpoints unavailable we don't retry - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientConnectionError, ClientConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 1 host_unavailable = original_location_cache.location_unavailability_info_by_endpoint.get(self.host) assert host_unavailable is not None assert len(host_unavailable.get('operationType')) == 1 assert 'Write' in host_unavailable.get('operationType') - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Reset the location cache's unavailable endpoints in order to try the same with other exceptions original_location_cache.location_unavailability_info_by_endpoint = {} original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] # Now we test ClientConnectionResetError, the subclass of ClientConnectionError - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientConnectionResetError, ClientConnectionResetError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientConnectionResetError, ClientConnectionResetError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 assert len(original_location_cache.location_unavailability_info_by_endpoint) == 1 host_unavailable = original_location_cache.location_unavailability_info_by_endpoint.get(self.host) assert host_unavailable is not None assert len(host_unavailable.get('operationType')) == 1 assert 'Write' in host_unavailable.get('operationType') - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Reset the location cache's unavailable endpoints in order to try the same with other exceptions original_location_cache.location_unavailability_info_by_endpoint = {} original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] # Now we test ServerConnectionError, the subclass of ClientConnectionError - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ServerConnectionError, ServerConnectionError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ServerConnectionError, ServerConnectionError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function # Reset the location cache's unavailable endpoints in order to try the same with other exceptions original_location_cache.location_unavailability_info_by_endpoint = {} original_location_cache.write_regional_routing_contexts = [self.REGIONAL_ENDPOINT, self.REGIONAL_ENDPOINT] # Now we test ClientOSError, the subclass of ClientConnectionError - try: - # Reset the function to reset the counter - mf = self.MockExecuteServiceResponseException(ClientOSError, ClientOSError()) - _retry_utility_async.ExecuteFunctionAsync = mf - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceResponseError: + mf = self.MockExecuteServiceResponseException(ClientOSError, ClientOSError()) + with patch.object(_retry_utility_async, 'ExecuteFunctionAsync', mf): + with pytest.raises(ServiceResponseError): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert mf.counter == 2 - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function async def test_global_endpoint_manager_retry_async(self): # For this test we mock both the ConnectionRetryPolicy and the GetDatabaseAccountStub @@ -403,40 +327,29 @@ async def test_global_endpoint_manager_retry_async(self): # - GetDatabaseAccountStub allows us to receive any number of endpoints for that call independent of account used exception = ServiceRequestError("mock exception") exception.exc_type = Exception - self.original_get_database_account_stub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub connection_policy = self.connectionPolicy connection_retry_policy = test_config.MockConnectionRetryPolicyAsync(resource_type="docs", error=exception) connection_policy.ConnectionRetryConfiguration = connection_retry_policy - async with CosmosClient(self.host, self.masterKey, connection_policy=connection_policy, - preferred_locations=[self.REGION1, self.REGION2]) as mock_client: - db = mock_client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_ID) - - try: - await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) - pytest.fail("Exception was not raised.") - except ServiceRequestError: + with patch.object(_global_endpoint_manager_async._GlobalEndpointManager, '_GetDatabaseAccountStub', + self.MockGetDatabaseAccountStub): + async with CosmosClient(self.host, self.masterKey, connection_policy=connection_policy, + preferred_locations=[self.REGION1, self.REGION2]) as mock_client: + db = mock_client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_ID) + + with pytest.raises((ServiceRequestError, CosmosHttpResponseError)): + await container.create_item({"id": str(uuid.uuid4()), "pk": str(uuid.uuid4())}) assert connection_retry_policy.counter == 3 # 4 total in region retries assert len(connection_retry_policy.request_endpoints) == 4 - except CosmosHttpResponseError as e: - print(e) - finally: - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub - - # Now we try with a read request - reset the policy to reset the counter - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - connection_retry_policy.request_endpoints = [] - try: - await container.read_item("some_id", "some_pk") - pytest.fail("Exception was not raised.") - except ServiceRequestError: + + # Now we try with a read request - reset the policy to reset the counter + connection_retry_policy.request_endpoints = [] + with pytest.raises(ServiceRequestError): + await container.read_item("some_id", "some_pk") assert connection_retry_policy.counter == 3 # 4 total requests in each main region (preferred read region 1 -> preferred read region 2) assert len(connection_retry_policy.request_endpoints) == 8 - finally: - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_get_database_account_stub class MockExecuteServiceRequestException(object): def __init__(self): From 936b10c942bef40644c57b7bbfe4b6328b1e25ba Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 5 May 2026 11:43:44 -0500 Subject: [PATCH 06/18] fixing agent review comments --- .../azure/cosmos/_cosmos_client_connection.py | 14 +- .../azure/cosmos/_query_aggregate_utils.py | 6 +- .../aio/_cosmos_client_connection_async.py | 14 +- .../test_feed_range_continuation_token.py | 4 + .../tests/test_partition_split_retry_unit.py | 139 ++++++++++++++++ .../test_partition_split_retry_unit_async.py | 150 ++++++++++++++++++ 6 files changed, 317 insertions(+), 10 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 715b2dab9934..ddecefb5ee21 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3208,9 +3208,13 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma if timeout is not None: kwargs.setdefault("timeout", timeout) - internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( - "_internal_response_headers_capture", None - ) + # Execution context injects this via request options; keep kwargs fallback + # for compatibility with call paths that still thread internal values there. + internal_headers_capture: Optional[Dict[str, Any]] = None + if options: + internal_headers_capture = options.get("_internal_response_headers_capture") + if internal_headers_capture is None: + internal_headers_capture = kwargs.pop("_internal_response_headers_capture", None) def _capture_internal_headers(headers: Mapping[str, Any]) -> None: # Local helper so flow analysis can narrow Optional[Dict] once @@ -3549,7 +3553,9 @@ def _checkpoint_and_reraise(error: Exception) -> None: partial_docs = backend_query_result.get("Documents") if backend_query_result else None if isinstance(results_docs, list) and isinstance(partial_docs, list): results_docs.extend(partial_docs) - elif backend_query_result: + elif not results and backend_query_result: + # Preserve already-accumulated rows: only seed from + # fallback payload when no prior merged result exists. results = backend_query_result previous_feedrange = pagination_state.head_range diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py index d688188f6fdb..0cae17919e39 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_aggregate_utils.py @@ -212,9 +212,11 @@ def _extract_outer_select_value_projection(normalized_query: str) -> Optional[st :rtype: Optional[str] """ select_value = "SELECT VALUE" - start_idx = normalized_query.find(select_value) - if start_idx < 0: + # Minimal hardening: only classify when the OUTER query starts with + # SELECT VALUE. This avoids matching nested SELECT VALUE occurrences. + if not normalized_query.startswith(select_value): return None + start_idx = 0 projection_start = start_idx + len(select_value) if projection_start < len(normalized_query) and normalized_query[projection_start] == " ": diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index a2cb325f0e2a..6cca98900b2a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -3012,9 +3012,13 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, # we need to set operation_state in kwargs as that's where it is looked at while sending the request kwargs.setdefault("timeout", timeout) - internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( - "_internal_response_headers_capture", None - ) + # Execution context injects this via request options; keep kwargs fallback + # for compatibility with call paths that still thread internal values there. + internal_headers_capture: Optional[Dict[str, Any]] = None + if options: + internal_headers_capture = options.get("_internal_response_headers_capture") + if internal_headers_capture is None: + internal_headers_capture = kwargs.pop("_internal_response_headers_capture", None) def _capture_internal_headers(headers: Mapping[str, Any]) -> None: # `internal_headers_capture` is Optional[Dict]; checking it @@ -3348,7 +3352,9 @@ def _checkpoint_and_reraise(error: Exception) -> None: partial_docs = backend_query_result.get("Documents") if backend_query_result else None if isinstance(results_docs, list) and isinstance(partial_docs, list): results_docs.extend(partial_docs) - elif backend_query_result: + elif not results and backend_query_result: + # Preserve already-accumulated rows: only seed from + # fallback payload when no prior merged result exists. results = backend_query_result previous_feedrange = pagination_state.head_range diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py index 12bf259e1b60..2ae2dac0832c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -1094,6 +1094,10 @@ def test_array_projection_subquery_is_not_classified_as_outer_aggregate(self): query = "SELECT VALUE ARRAY(SELECT VALUE COUNT(1) FROM d IN c.items) FROM c" assert _get_select_value_aggregate_function(query) is None + def test_nested_select_value_in_where_subquery_does_not_drive_outer_detection(self): + query = "SELECT c.count FROM c WHERE c.count IN (SELECT VALUE COUNT(1) FROM c)" + assert _get_select_value_aggregate_function(query) is None + class TestAggregateClassificationHeuristics: def test_block_comment_prefix_does_not_drive_outer_select_value_detection(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index 32d2aa0399e0..d46aca8e0490 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -13,6 +13,7 @@ import pytest from azure.cosmos import exceptions +from azure.cosmos._cosmos_client_connection import CosmosClientConnection from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders @@ -75,6 +76,90 @@ class TestPartitionSplitRetryUnit(unittest.TestCase): Sync unit tests for 410 partition split retry logic. """ + @staticmethod + def _create_minimal_connection() -> CosmosClientConnection: + client = CosmosClientConnection.__new__(CosmosClientConnection) + client.default_headers = {} + client.last_response_headers = {} + client._UpdateSessionIfRequired = lambda *args, **kwargs: None + return client + + def test_queryfeed_internal_capture_uses_options_dict(self): + """QueryFeed should honor _internal_response_headers_capture from options.""" + client = self._create_minimal_connection() + captured_headers = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token", "x-ms-request-charge": "1.0"} + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object( + client, + '_CosmosClientConnection__Get', + return_value=({"Documents": [{"id": "doc1"}]}, expected_headers), + ): + docs, response_headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": captured_headers}, + ) + + self.assertEqual(docs, [{"id": "doc1"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(captured_headers, expected_headers) + + def test_queryfeed_internal_capture_falls_back_to_kwargs(self): + """QueryFeed should still support kwargs-based internal capture for compatibility.""" + client = self._create_minimal_connection() + kwargs_capture = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-kwargs", "x-ms-request-charge": "1.0"} + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object( + client, + '_CosmosClientConnection__Get', + return_value=({"Documents": [{"id": "doc2"}]}, expected_headers), + ): + docs, response_headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc2"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(kwargs_capture, expected_headers) + + def test_queryfeed_internal_capture_options_wins_over_kwargs(self): + """When both are present, options-based capture dict should win over kwargs fallback.""" + client = self._create_minimal_connection() + options_capture = {"stale-options": "value"} + kwargs_capture = {"stale-kwargs": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-both", "x-ms-request-charge": "1.0"} + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object( + client, + '_CosmosClientConnection__Get', + return_value=({"Documents": [{"id": "doc3"}]}, expected_headers), + ): + docs, response_headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": options_capture}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc3"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(options_capture, expected_headers) + self.assertEqual(kwargs_capture, {"stale-kwargs": "value"}) + def test_execution_context_state_reset_on_partition_split(self): """ Test that execution context state is properly reset on 410 partition split retry. @@ -183,6 +268,60 @@ def mock_fetch_function(options): assert seen_continuations == ["checkpoint-token"] assert result == expected_docs + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_uses_queryfeed_captured_checkpoint_end_to_end(self, mock_execute): + """End-to-end: QueryFeed stamps capture dict, 410 occurs, retry resumes from checkpoint token.""" + mock_client = MockClient() + query_client = self._create_minimal_connection() + query_client._query_compatibility_mode = query_client._QueryCompatibilityMode.Default + + context = None + seen_continuations = [] + execute_call_count = [0] + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + continuation = req_headers.get(HttpHeaders.Continuation) + if continuation: + return ({"Documents": [{"id": "resumed"}]}, {}) + return ({"Documents": [{"id": "checkpoint-page"}]}, {HttpHeaders.Continuation: "checkpoint-token"}) + + def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + execute_call_count[0] += 1 + if execute_call_count[0] == 1: + callback() + raise create_410_partition_split_error() + return callback() + + mock_execute.side_effect = execute_side_effect + + def fetch_function(options): + seen_continuations.append(options.get("continuation")) + docs, headers = query_client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options=options, + ) + return docs, headers + + def mock_get_headers(*args, **kwargs): + options = args[7] if len(args) > 7 else kwargs.get("options", {}) + headers = {} + if options and options.get("continuation") is not None: + headers[HttpHeaders.Continuation] = options.get("continuation") + return headers + + context = _DefaultQueryExecutionContext(mock_client, {}, fetch_function) + + with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', side_effect=mock_get_headers): + with patch('azure.cosmos._cosmos_client_connection.base.set_session_token_header', return_value=None): + with patch.object(query_client, '_CosmosClientConnection__Post', side_effect=post_side_effect): + result = context._fetch_items_helper_with_retries(fetch_function) + + assert execute_call_count[0] == 2 + assert seen_continuations == [None, "checkpoint-token"] + assert result == [{"id": "resumed"}] + @patch('azure.cosmos._retry_utility.Execute') def test_retry_with_410_ignores_stale_shared_client_headers(self, mock_execute): """Retry resumes from request-local captured headers, not shared client headers.""" diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index 8d09fb7fcd31..153d1b7f92ef 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -13,6 +13,7 @@ import pytest from azure.cosmos import exceptions +from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders from azure.cosmos.aio import CosmosClient # noqa: F401 - needed to resolve circular imports from azure.cosmos._execution_context.aio.base_execution_context import _DefaultQueryExecutionContext @@ -77,6 +78,99 @@ class TestPartitionSplitRetryUnitAsync(unittest.IsolatedAsyncioTestCase): Async unit tests for 410 partition split retry logic. """ + @staticmethod + def _create_minimal_connection() -> CosmosClientConnection: + client = CosmosClientConnection.__new__(CosmosClientConnection) + client.default_headers = {} + client.last_response_headers = {} + client._UpdateSessionIfRequired = lambda *args, **kwargs: None + return client + + async def test_queryfeed_internal_capture_uses_options_dict_async(self): + """Async QueryFeed should honor _internal_response_headers_capture from options.""" + client = self._create_minimal_connection() + captured_headers = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token", "x-ms-request-charge": "1.0"} + + async def _noop_set_session(*args, **kwargs): + return None + + async def _fake_get(*args, **kwargs): + return {"Documents": [{"id": "doc1"}]}, expected_headers + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', return_value={}): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(client, '_CosmosClientConnection__Get', side_effect=_fake_get): + docs, response_headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": captured_headers}, + ) + + self.assertEqual(docs, [{"id": "doc1"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(captured_headers, expected_headers) + self.assertEqual(client.last_response_headers, expected_headers) + + async def test_queryfeed_internal_capture_falls_back_to_kwargs_async(self): + """Async QueryFeed should still support kwargs-based internal capture for compatibility.""" + client = self._create_minimal_connection() + kwargs_capture = {"stale": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-kwargs", "x-ms-request-charge": "1.0"} + + async def _noop_set_session(*args, **kwargs): + return None + + async def _fake_get(*args, **kwargs): + return {"Documents": [{"id": "doc2"}]}, expected_headers + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', return_value={}): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(client, '_CosmosClientConnection__Get', side_effect=_fake_get): + docs, response_headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc2"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(kwargs_capture, expected_headers) + self.assertEqual(client.last_response_headers, expected_headers) + + async def test_queryfeed_internal_capture_options_wins_over_kwargs_async(self): + """When both are present, async QueryFeed should prefer options-based capture dict.""" + client = self._create_minimal_connection() + options_capture = {"stale-options": "value"} + kwargs_capture = {"stale-kwargs": "value"} + expected_headers = {HttpHeaders.Continuation: "checkpoint-token-both", "x-ms-request-charge": "1.0"} + + async def _noop_set_session(*args, **kwargs): + return None + + async def _fake_get(*args, **kwargs): + return {"Documents": [{"id": "doc3"}]}, expected_headers + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', return_value={}): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(client, '_CosmosClientConnection__Get', side_effect=_fake_get): + docs, response_headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=None, + options={"_internal_response_headers_capture": options_capture}, + _internal_response_headers_capture=kwargs_capture, + ) + + self.assertEqual(docs, [{"id": "doc3"}]) + self.assertEqual(response_headers, expected_headers) + self.assertEqual(options_capture, expected_headers) + self.assertEqual(kwargs_capture, {"stale-kwargs": "value"}) + self.assertEqual(client.last_response_headers, expected_headers) + async def test_execution_context_state_reset_on_partition_split_async(self): """ Test that execution context state is properly reset on 410 partition split retry (async). @@ -181,6 +275,62 @@ async def mock_fetch_function(options): assert seen_continuations == ["checkpoint-token"] assert result == expected_docs + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_uses_queryfeed_captured_checkpoint_end_to_end_async(self, mock_execute): + """End-to-end async: QueryFeed stamps capture dict, 410 occurs, retry resumes from checkpoint token.""" + mock_client = MockClient() + query_client = self._create_minimal_connection() + query_client._query_compatibility_mode = query_client._QueryCompatibilityMode.Default + + seen_continuations = [] + execute_call_count = [0] + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + continuation = req_headers.get(HttpHeaders.Continuation) + if continuation: + return {"Documents": [{"id": "resumed"}]}, {} + return {"Documents": [{"id": "checkpoint-page"}]}, {HttpHeaders.Continuation: "checkpoint-token"} + + async def execute_side_effect(_client, _global_endpoint_manager, callback, **kwargs): + execute_call_count[0] += 1 + if execute_call_count[0] == 1: + await callback() + raise create_410_partition_split_error() + return await callback() + + mock_execute.side_effect = execute_side_effect + + async def _noop_set_session(*args, **kwargs): + return None + + async def fetch_function(options): + seen_continuations.append(options.get("continuation")) + docs, headers = await query_client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options=options, + ) + return docs, headers + + def mock_get_headers(*args, **kwargs): + options = args[7] if len(args) > 7 else kwargs.get("options", {}) + headers = {} + if options and options.get("continuation") is not None: + headers[HttpHeaders.Continuation] = options.get("continuation") + return headers + + context = _DefaultQueryExecutionContext(mock_client, {}, fetch_function) + + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders', side_effect=mock_get_headers): + with patch('azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async', side_effect=_noop_set_session): + with patch.object(query_client, '_CosmosClientConnection__Post', side_effect=post_side_effect): + result = await context._fetch_items_helper_with_retries(fetch_function) + + assert execute_call_count[0] == 2 + assert seen_continuations == [None, "checkpoint-token"] + assert result == [{"id": "resumed"}] + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') async def test_retry_with_410_ignores_stale_shared_client_headers_async(self, mock_execute): """Retry resumes from request-local captured headers, not shared client headers.""" From 85a900b7b7377ca85dc207dab493ba60845f772b Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Sun, 10 May 2026 13:16:40 -0500 Subject: [PATCH 07/18] removing env variable flag --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 4 +- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 14 +- .../azure-cosmos/azure/cosmos/_constants.py | 1 - .../azure/cosmos/_cosmos_client_connection.py | 44 +- .../_routing/feed_range_continuation.py | 10 +- .../aio/_cosmos_client_connection_async.py | 40 +- .../test_feed_range_continuation_token.py | 6 +- .../tests/test_partition_split_retry_unit.py | 119 +++++- .../test_partition_split_retry_unit_async.py | 116 ++++- sdk/cosmos/azure-cosmos/tests/test_query.py | 403 ++++-------------- .../azure-cosmos/tests/test_query_async.py | 394 ++++------------- 11 files changed, 457 insertions(+), 694 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 696b8fcae973..1db2cb11fa9e 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,8 +3,8 @@ ### 4.14.7 (Unreleased) #### Bugs Fixed -* Fixed bug where `SELECT VALUE AVG(...)` queries spanning multiple physical partitions returned mathematically incorrect merged values from client-side aggregation. These queries now raise `ValueError`. -* Fixed a bug in `query_items(feed_range=...)` where pagination could return incorrect results after a partition split caused the supplied feed range to overlap multiple physical partitions. +* Fixed `SELECT VALUE` aggregation classification across partitions: booleans are no longer treated as numeric aggregates, non-aggregate numeric projections are no longer merged, and `MIN`/`MAX` detection is now correct. See [PR 46692](https://github.com/Azure/azure-sdk-for-python/pull/46692) +* Fixed a bug in `query_items(feed_range=...)` where pagination could return incorrect results after a partition split caused the supplied feed range to overlap multiple physical partitions. See [PR 46692](https://github.com/Azure/azure-sdk-for-python/pull/46692) * Fixed bug where unavailable regional endpoints were dropped from the routing list instead of being kept as fallback options. See [PR 45200](https://github.com/Azure/azure-sdk-for-python/pull/45200) ### 4.14.6 (2026-02-02) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 27dcea10f30e..ffab91d5dc9c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -1,4 +1,4 @@ -# The MIT License (MIT) +# The MIT License (MIT) # Copyright (c) 2014 Microsoft Corporation # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -207,7 +207,7 @@ def _merge_query_results( elif aggregate_fn == "MAX": results_docs[0] = max(results_docs[0], partial_docs[0]) # type: ignore[index] elif aggregate_fn == "AVG": - raise _UnsupportedValueAvgMergeError( + raise ValueError( "VALUE AVG aggregate merge across partitions is not supported client-side." ) else: @@ -236,7 +236,8 @@ def _raise_query_merge_value_error(merge_error: ValueError) -> None: :type merge_error: ValueError :raises ValueError: Always re-raises, potentially with a clearer message. """ - if isinstance(merge_error, _UnsupportedValueAvgMergeError): + merge_message = str(merge_error) + if "VALUE AVG aggregate merge across partitions is not supported client-side." in merge_message: raise ValueError( "Unsupported query shape for range-scoped pagination: " "SELECT VALUE AVG(...) cannot be merged client-side when the query " @@ -245,13 +246,6 @@ def _raise_query_merge_value_error(merge_error: ValueError) -> None: raise merge_error -class _UnsupportedValueAvgMergeError(ValueError): - """Internal marker for unsupported client-side SELECT VALUE AVG(...) merge. - - This type is only used for SDK control flow (type-based handling), then - translated to a clearer public ValueError for callers. - """ - def GetHeaders( # pylint: disable=too-many-statements,too-many-branches cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 2efb813a498a..f49798c8f8b8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -66,7 +66,6 @@ class _Constants: AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default" INFERENCE_SERVICE_DEFAULT_SCOPE = "https://dbinference.azure.com/.default" SEMANTIC_RERANKER_INFERENCE_ENDPOINT: str = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT" - EMIT_STRUCTURED_CONTINUATION_PK_CONFIG: str = "AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK" # Health Check Retry Policy constants AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index ddecefb5ee21..24917acdf74d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -1,4 +1,4 @@ -# The MIT License (MIT) +# The MIT License (MIT) # Copyright (c) 2014 Microsoft Corporation # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -159,10 +159,6 @@ def __init__( # pylint: disable=too-many-statements """ self.client_id = str(uuid.uuid4()) self.url_connection = url_connection - self._emit_structured_continuation_pk = os.environ.get( - Constants.EMIT_STRUCTURED_CONTINUATION_PK_CONFIG, - "", - ).strip().lower() in ("1", "true", "yes", "on") self.master_key: Optional[str] = None self.resource_tokens: Optional[Mapping[str, Any]] = None @@ -3171,6 +3167,7 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma partition_key_range_id: Optional[str] = None, response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None, is_query_plan: bool = False, + response_headers_list: Optional[list[CaseInsensitiveDict]] = None, **kwargs: Any ) -> Tuple[list[dict[str, Any]], CaseInsensitiveDict]: """Query for more than one Azure Cosmos resources. @@ -3208,13 +3205,20 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma if timeout is not None: kwargs.setdefault("timeout", timeout) - # Execution context injects this via request options; keep kwargs fallback - # for compatibility with call paths that still thread internal values there. - internal_headers_capture: Optional[Dict[str, Any]] = None - if options: - internal_headers_capture = options.get("_internal_response_headers_capture") - if internal_headers_capture is None: - internal_headers_capture = kwargs.pop("_internal_response_headers_capture", None) + # The capture dict can arrive via two upstream paths: + # 1. The query execution context puts it into ``options`` (the + # common case for query pagination — see + # ``_QueryExecutionContextBase._fetch_items_helper_no_retries``). + # 2. ``routing_map_provider.get_routing_map`` puts it into + # ``kwargs`` for PK-range fetches. + # Honour both so checkpoint-on-failure works on every path. + internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( + "_internal_response_headers_capture", None + ) + if internal_headers_capture is None and isinstance(options, dict): + internal_headers_capture = options.pop( + "_internal_response_headers_capture", None + ) def _capture_internal_headers(headers: Mapping[str, Any]) -> None: # Local helper so flow analysis can narrow Optional[Dict] once @@ -3273,6 +3277,8 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: self.last_response_headers = get_response_headers if internal_headers_capture is not None: _capture_internal_headers(get_response_headers) + if response_headers_list is not None: + response_headers_list.append(get_response_headers.copy()) if response_hook: response_hook(get_response_headers, result) return __GetBodiesFromQueryResult(result), get_response_headers @@ -3349,7 +3355,6 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) query_hash = _hash_query_spec(query) feedrange_hash = _hash_feed_range(feed_range_epk) - should_emit_structured_full_pk = self._emit_structured_continuation_pk inbound_serialized_continuation = options.get("continuation") inbound_token_payload = _decode_token(inbound_serialized_continuation) legacy_bridge_in_use = False @@ -3424,7 +3429,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, - should_emit_structured_full_pk, query_hash, feedrange_hash, ) @@ -3433,7 +3437,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: "Failed to write continuation while handling query POST failure: %s", continuation_write_error, ) - _capture_internal_headers(feedrange_response_headers) raise error # NOTE: Keep this feed_range pagination loop in sync with @@ -3525,6 +3528,8 @@ def _checkpoint_and_reraise(error: Exception) -> None: self._UpdateSessionIfRequired( req_headers, backend_query_result, backend_response_headers ) + if response_headers_list is not None: + response_headers_list.append(backend_response_headers.copy()) if response_hook: response_hook(backend_response_headers, backend_query_result) return __GetBodiesFromQueryResult(backend_query_result), backend_response_headers @@ -3553,14 +3558,14 @@ def _checkpoint_and_reraise(error: Exception) -> None: partial_docs = backend_query_result.get("Documents") if backend_query_result else None if isinstance(results_docs, list) and isinstance(partial_docs, list): results_docs.extend(partial_docs) - elif not results and backend_query_result: - # Preserve already-accumulated rows: only seed from - # fallback payload when no prior merged result exists. + elif backend_query_result: results = backend_query_result previous_feedrange = pagination_state.head_range previous_backend_continuation = pagination_state.head_bc page_items_returned = _count_page_items_from_partial_result(backend_query_result, query) + if response_headers_list is not None: + response_headers_list.append(backend_response_headers.copy()) if response_hook: response_hook(backend_response_headers, backend_query_result) pagination_state.apply_post_result( @@ -3607,7 +3612,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, - should_emit_structured_full_pk, query_hash, feedrange_hash, ) @@ -3632,6 +3636,8 @@ def _checkpoint_and_reraise(error: Exception) -> None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = post_response_headers[INDEX_METRICS_HEADER] post_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) + if response_headers_list is not None: + response_headers_list.append(post_response_headers.copy()) if response_hook: response_hook(post_response_headers, result) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py index 2f55f2ea5cd1..ada117b41d68 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py @@ -611,14 +611,14 @@ def _write_query_outbound_continuation( query: Any, feed_range_epk: routing_range.Range, is_full_pk_structured_scope: bool, - should_emit_structured_full_pk: bool, query_hash: str, feedrange_hash: str, ) -> None: """Write outbound continuation for feed-range pagination. - Full-PK queries keep legacy continuation emission unless structured - emission is explicitly enabled by the client-level env-var contract. + Full-PK queries always emit the legacy single-string continuation + so persisted bookmarks remain readable by older SDK versions. + Feed-range and prefix queries always emit the structured envelope. :param last_response_headers: Response headers to mutate. :type last_response_headers: MutableMapping[str, Any] @@ -632,8 +632,6 @@ def _write_query_outbound_continuation( :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range :param is_full_pk_structured_scope: Whether request scope is full-PK on structured path. :type is_full_pk_structured_scope: bool - :param should_emit_structured_full_pk: Whether structured emission is enabled for full-PK. - :type should_emit_structured_full_pk: bool :param query_hash: Precomputed query hash for outbound token identity. :type query_hash: str :param feedrange_hash: Precomputed feed range hash for outbound token identity. @@ -641,7 +639,7 @@ def _write_query_outbound_continuation( :returns: None. Mutates ``last_response_headers`` in place. :rtype: None """ - if is_full_pk_structured_scope and not should_emit_structured_full_pk: + if is_full_pk_structured_scope: legacy_outbound = pagination_state.head_bc if legacy_outbound is None: last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 6cca98900b2a..dc647335b382 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -159,8 +159,6 @@ def __init__( # pylint: disable=too-many-statements """ self.client_id = str(uuid.uuid4()) self.url_connection = url_connection - emit_structured_env = os.environ.get(Constants.EMIT_STRUCTURED_CONTINUATION_PK_CONFIG, "") - self._emit_structured_continuation_pk = emit_structured_env.strip().lower() in ("1", "true", "yes", "on") self.master_key: Optional[str] = None self.resource_tokens: Optional[Mapping[str, Any]] = None self.aad_credentials: Optional[AsyncTokenCredential] = None @@ -2972,6 +2970,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, partition_key_range_id: Optional[str] = None, response_hook: Optional[Callable[[Mapping[str, Any], dict[str, Any]], None]] = None, is_query_plan: bool = False, + response_headers_list: Optional[list[CaseInsensitiveDict]] = None, **kwargs: Any ) -> list[dict[str, Any]]: """Query for more than one Azure Cosmos resources. @@ -3012,13 +3011,20 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, # we need to set operation_state in kwargs as that's where it is looked at while sending the request kwargs.setdefault("timeout", timeout) - # Execution context injects this via request options; keep kwargs fallback - # for compatibility with call paths that still thread internal values there. - internal_headers_capture: Optional[Dict[str, Any]] = None - if options: - internal_headers_capture = options.get("_internal_response_headers_capture") - if internal_headers_capture is None: - internal_headers_capture = kwargs.pop("_internal_response_headers_capture", None) + # The capture dict can arrive via two upstream paths: + # 1. The query execution context puts it into ``options`` (the + # common case for query pagination — see the async + # ``_QueryExecutionContextBase._fetch_items_helper_no_retries``). + # 2. ``routing_map_provider.get_routing_map`` puts it into + # ``kwargs`` for PK-range fetches. + # Honour both so checkpoint-on-failure works on every path. + internal_headers_capture: Optional[Dict[str, Any]] = kwargs.pop( + "_internal_response_headers_capture", None + ) + if internal_headers_capture is None and isinstance(options, dict): + internal_headers_capture = options.pop( + "_internal_response_headers_capture", None + ) def _capture_internal_headers(headers: Mapping[str, Any]) -> None: # `internal_headers_capture` is Optional[Dict]; checking it @@ -3078,6 +3084,8 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: if internal_headers_capture is not None: _capture_internal_headers(get_response_headers) self._UpdateSessionIfRequired(headers, result, get_response_headers) + if response_headers_list is not None: + response_headers_list.append(get_response_headers.copy()) if response_hook: response_hook(self.last_response_headers, result) return __GetBodiesFromQueryResult(result) @@ -3144,7 +3152,6 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) query_hash = _hash_query_spec(query) feedrange_hash = _hash_feed_range(feed_range_epk) - should_emit_structured_full_pk = self._emit_structured_continuation_pk inbound_serialized_continuation = options.get("continuation") inbound_token_payload = _decode_token(inbound_serialized_continuation) legacy_bridge_in_use = False @@ -3215,7 +3222,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, - should_emit_structured_full_pk, query_hash, feedrange_hash, ) @@ -3224,7 +3230,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: "Failed to write continuation while handling query POST failure: %s", continuation_write_error, ) - _capture_internal_headers(feedrange_response_headers) raise error # NOTE: Keep this feed_range pagination loop in sync with @@ -3324,6 +3329,8 @@ def _checkpoint_and_reraise(error: Exception) -> None: self._UpdateSessionIfRequired( req_headers, backend_query_result, backend_response_headers ) + if response_headers_list is not None: + response_headers_list.append(backend_response_headers.copy()) if response_hook: response_hook(backend_response_headers, backend_query_result) return __GetBodiesFromQueryResult(backend_query_result) @@ -3352,14 +3359,14 @@ def _checkpoint_and_reraise(error: Exception) -> None: partial_docs = backend_query_result.get("Documents") if backend_query_result else None if isinstance(results_docs, list) and isinstance(partial_docs, list): results_docs.extend(partial_docs) - elif not results and backend_query_result: - # Preserve already-accumulated rows: only seed from - # fallback payload when no prior merged result exists. + elif backend_query_result: results = backend_query_result previous_feedrange = pagination_state.head_range previous_backend_continuation = pagination_state.head_bc page_items_returned = _count_page_items_from_partial_result(backend_query_result, query) + if response_headers_list is not None: + response_headers_list.append(backend_response_headers.copy()) if response_hook: response_hook(backend_response_headers, backend_query_result) pagination_state.apply_post_result( @@ -3406,7 +3413,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, - should_emit_structured_full_pk, query_hash, feedrange_hash, ) @@ -3433,6 +3439,8 @@ def _checkpoint_and_reraise(error: Exception) -> None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = self.last_response_headers[INDEX_METRICS_HEADER] self.last_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) + if response_headers_list is not None: + response_headers_list.append(post_response_headers.copy()) if response_hook: response_hook(self.last_response_headers, result) diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py index 2ae2dac0832c..fa965ee34931 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -1015,15 +1015,13 @@ def test_value_merge_raises_if_aggregate_function_detection_is_missing(self, mon def test_value_avg_merge_raises_as_unsupported(self): query = "SELECT VALUE AVG(c.value) FROM c" - with pytest.raises(_base._UnsupportedValueAvgMergeError) as excinfo: + with pytest.raises(ValueError) as excinfo: _base._merge_query_results({"Documents": [7.0]}, {"Documents": [3.0]}, query) assert "VALUE AVG aggregate merge" in str(excinfo.value) def test_raise_query_merge_value_error_rewrites_value_avg_message(self): - original = _base._UnsupportedValueAvgMergeError( - "VALUE AVG aggregate merge across partitions is not supported client-side." - ) + original = ValueError("VALUE AVG aggregate merge across partitions is not supported client-side.") with pytest.raises(ValueError) as excinfo: _base._raise_query_merge_value_error(original) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index d46aca8e0490..935fce63c184 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -8,7 +8,7 @@ import gc import time import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -36,6 +36,12 @@ def is_circuit_breaker_applicable(self, request): return False +class MockRoutingMapProvider: + """Mock routing map provider with a collection routing map cache.""" + def __init__(self): + self._collection_routing_map_by_item = {} + + class MockClient: """Mock Cosmos client for testing partition split retry logic.""" def __init__(self): @@ -133,11 +139,16 @@ def test_queryfeed_internal_capture_falls_back_to_kwargs(self): self.assertEqual(response_headers, expected_headers) self.assertEqual(kwargs_capture, expected_headers) - def test_queryfeed_internal_capture_options_wins_over_kwargs(self): - """When both are present, options-based capture dict should win over kwargs fallback.""" + def test_queryfeed_internal_capture_both_present_populates_one(self): + """When both options- and kwargs-based capture dicts are present + (a configuration that does not occur in production — the two + upstream paths are mutually exclusive by design), QueryFeed must + populate exactly one of the two capture dicts with the response + headers. Precedence is intentionally unspecified. + """ client = self._create_minimal_connection() - options_capture = {"stale-options": "value"} - kwargs_capture = {"stale-kwargs": "value"} + options_capture: dict = {} + kwargs_capture: dict = {} expected_headers = {HttpHeaders.Continuation: "checkpoint-token-both", "x-ms-request-charge": "1.0"} with patch('azure.cosmos._cosmos_client_connection.base.GetHeaders', return_value={}): @@ -157,8 +168,11 @@ def test_queryfeed_internal_capture_options_wins_over_kwargs(self): self.assertEqual(docs, [{"id": "doc3"}]) self.assertEqual(response_headers, expected_headers) - self.assertEqual(options_capture, expected_headers) - self.assertEqual(kwargs_capture, {"stale-kwargs": "value"}) + populated = [d for d in (options_capture, kwargs_capture) if d == expected_headers] + self.assertEqual( + len(populated), 1, + f"expected exactly one capture dict populated; got options={options_capture!r}, kwargs={kwargs_capture!r}", + ) def test_execution_context_state_reset_on_partition_split(self): """ @@ -625,6 +639,97 @@ def mock_fetch_function(options): assert pk_refresh_calls == 0, \ f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" + def test_querfeed_populates_capture_dict_from_options(self): + """`__QueryFeed` must read the capture dict from `options` and + populate it from the underlying response headers. + + This is the producer-side counterpart to the checkpoint tests + above: it does not inject into the capture dict, it asserts that + `__QueryFeed` itself does the population. Catches the + `options`-vs-`kwargs` extraction regression. + """ + from unittest.mock import patch as _patch + + # Build a CosmosClientConnection without running __init__; we + # only need the attributes that the no-query (read-feed) branch + # of __QueryFeed touches. + conn = object.__new__(CosmosClientConnection) + conn.default_headers = {} + conn.last_response_headers = {} + conn.availability_strategy = None + conn.availability_strategy_executor = None + conn._global_endpoint_manager = MockGlobalEndpointManager() + conn._routing_map_provider = MockRoutingMapProvider() + conn.session = None + conn.connection_policy = MagicMock() + + capture_dict = {} + options = { + "_internal_response_headers_capture": capture_dict, + } + + canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-querfeed"} + + request_obj_mock = MagicMock( + set_excluded_location_from_options=MagicMock(), + set_availability_strategy=MagicMock(), + headers={}, + ) + + # Patch the heavy collaborators inside __QueryFeed's no-query + # branch so we can drive it without a real pipeline. + with _patch( + "azure.cosmos._cosmos_client_connection.base.GetHeaders", + return_value={}, + ), \ + _patch( + "azure.cosmos._cosmos_client_connection.base.set_session_token_header" + ), \ + _patch( + "azure.cosmos._cosmos_client_connection.RequestObject", + return_value=request_obj_mock, + ) as request_obj_ctor, \ + _patch.object( + CosmosClientConnection, + "_CosmosClientConnection__Get", + return_value=( + {"Documents": [{"id": "1"}], "_count": 1}, + canned_headers, + ), + ) as mock_get: + + _ = request_obj_ctor # silence unused-warning + + # Invoke the name-mangled private method directly. + result, headers = conn._CosmosClientConnection__QueryFeed( + "/dbs/db/colls/c/docs", + "docs", + "rid1", + lambda r: r["Documents"], + lambda _c, b: b, + None, # query=None -> read-feed branch -> __Get + options, + None, # partition_key_range_id + ) + + assert mock_get.called, "expected __Get to be invoked on the no-query path" + + assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-querfeed", ( + f"capture dict was not populated by __QueryFeed; got {capture_dict!r}. " + "This indicates __QueryFeed is not reading " + "'_internal_response_headers_capture' from options." + ) + + # the marker key must have been removed from options so it + # never leaks downstream into header construction or RequestObject. + assert "_internal_response_headers_capture" not in options, ( + "__QueryFeed should pop the capture marker out of options" + ) + + # Sanity check on the result tuple shape. + assert result == [{"id": "1"}] + assert headers is canned_headers + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index 153d1b7f92ef..2c3356568965 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -8,7 +8,7 @@ import gc import time import unittest -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -36,6 +36,12 @@ def is_circuit_breaker_applicable(self, request): return False +class MockRoutingMapProvider: + """Mock routing map provider with a collection routing map cache.""" + def __init__(self): + self._collection_routing_map_by_item = {} + + class MockClient: """Mock Cosmos client for testing partition split retry logic.""" def __init__(self): @@ -141,11 +147,16 @@ async def _fake_get(*args, **kwargs): self.assertEqual(kwargs_capture, expected_headers) self.assertEqual(client.last_response_headers, expected_headers) - async def test_queryfeed_internal_capture_options_wins_over_kwargs_async(self): - """When both are present, async QueryFeed should prefer options-based capture dict.""" + async def test_queryfeed_internal_capture_both_present_populates_one_async(self): + """When both options- and kwargs-based capture dicts are present + (a configuration that does not occur in production — the two + upstream paths are mutually exclusive by design), async QueryFeed + must populate exactly one of the two capture dicts with the + response headers. Precedence is intentionally unspecified. + """ client = self._create_minimal_connection() - options_capture = {"stale-options": "value"} - kwargs_capture = {"stale-kwargs": "value"} + options_capture: dict = {} + kwargs_capture: dict = {} expected_headers = {HttpHeaders.Continuation: "checkpoint-token-both", "x-ms-request-charge": "1.0"} async def _noop_set_session(*args, **kwargs): @@ -167,8 +178,11 @@ async def _fake_get(*args, **kwargs): self.assertEqual(docs, [{"id": "doc3"}]) self.assertEqual(response_headers, expected_headers) - self.assertEqual(options_capture, expected_headers) - self.assertEqual(kwargs_capture, {"stale-kwargs": "value"}) + populated = [d for d in (options_capture, kwargs_capture) if d == expected_headers] + self.assertEqual( + len(populated), 1, + f"expected exactly one capture dict populated; got options={options_capture!r}, kwargs={kwargs_capture!r}", + ) self.assertEqual(client.last_response_headers, expected_headers) async def test_execution_context_state_reset_on_partition_split_async(self): @@ -623,6 +637,94 @@ async def mock_fetch_function(options): assert pk_refresh_calls == 0, \ f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" + async def test_querfeed_populates_capture_dict_from_options_async(self): + """Async `__QueryFeed` must read the capture dict from `options` + and populate it from the underlying response headers, with no + test-side injection. Catches the `options`-vs-`kwargs` + extraction regression on the async path. + """ + from unittest.mock import patch as _patch + + # Build a CosmosClientConnection without running __init__; we + # only need the attributes that the no-query (read-feed) branch + # of async __QueryFeed touches. + conn = object.__new__(CosmosClientConnection) + conn.default_headers = {} + conn.last_response_headers = {} + conn.availability_strategy = None + conn.availability_strategy_max_concurrency = None + conn._global_endpoint_manager = MockGlobalEndpointManager() + conn._routing_map_provider = MagicMock(_collection_routing_map_by_item={}) + conn.session = None + conn.connection_policy = MagicMock() + conn._UpdateSessionIfRequired = MagicMock() + + capture_dict = {} + options = { + "_internal_response_headers_capture": capture_dict, + } + + canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-querfeed-async"} + + request_obj_mock = MagicMock( + set_excluded_location_from_options=MagicMock(), + set_availability_strategy=MagicMock(), + headers={}, + operation_type="ReadFeed", + ) + + # Patch the heavy collaborators inside async __QueryFeed's + # no-query branch so we can drive it without a real pipeline. + with _patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", + return_value={}, + ), \ + _patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + new=AsyncMock(), + ), \ + _patch( + "azure.cosmos.aio._cosmos_client_connection_async._request_object.RequestObject", + return_value=request_obj_mock, + ), \ + _patch.object( + CosmosClientConnection, + "_CosmosClientConnection__Get", + new=AsyncMock(return_value=( + {"Documents": [{"id": "1"}], "_count": 1}, + canned_headers, + )), + ) as mock_get: + + # Invoke the name-mangled private async method directly. + result = await conn._CosmosClientConnection__QueryFeed( + "/dbs/db/colls/c/docs", + "docs", + "rid1", + lambda r: r["Documents"], + lambda _c, b: b, + None, # query=None -> read-feed branch -> __Get + options, + None, # partition_key_range_id + ) + + assert mock_get.await_count >= 1, "expected __Get to be awaited on the no-query path" + + assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-querfeed-async", ( + f"capture dict was not populated by async __QueryFeed; got {capture_dict!r}. " + "This indicates async __QueryFeed is not reading " + "'_internal_response_headers_capture' from options." + ) + + # And the marker key must have been removed from options so it + # never leaks downstream into header construction or RequestObject. + assert "_internal_response_headers_capture" not in options, ( + "async __QueryFeed should pop the capture marker out of options" + ) + + # Sanity check: async no-query branch returns just the body list. + assert result == [{"id": "1"}] + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 54abca5680cd..deb6c8c7479e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -4,8 +4,6 @@ import os import unittest import uuid -from contextlib import contextmanager -from unittest.mock import patch import pytest @@ -42,18 +40,6 @@ def setUpClass(cls): cls.client = cosmos_client.CosmosClient(cls.host, cls.credential, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) - @contextmanager - def _new_client_with_structured_full_pk_env(self, value: str): - use_multiple_write_locations = os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True" - with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": value}, clear=False): - with cosmos_client.CosmosClient( - self.host, - self.credential, - multiple_write_locations=use_multiple_write_locations, - ) as client: - database = client.get_database_client(self.TEST_DATABASE_ID) - created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - yield client, created_collection def test_first_and_last_slashes_trimmed_for_query_string(self): created_collection = self.created_db.create_container( @@ -508,45 +494,6 @@ def test_full_pk_continuation_emits_legacy_by_default(self): self.assertIsNotNone(token) self.assertIsNone(_decode_token(token)) - def test_full_pk_continuation_emits_structured_with_env_var(self): - with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - query_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - pager = query_iterable.by_page() - pager.next() - token = pager.continuation_token - - self.assertIsNotNone(token) - self.assertIsNotNone(_decode_token(token)) - - def test_full_pk_continuation_emits_structured_with_env_var_and_new_client(self): - with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": "true"}, clear=False): - with cosmos_client.CosmosClient( - self.host, - self.credential, - ) as client: - database = client.get_database_client(self.TEST_DATABASE_ID) - created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - query_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - pager = query_iterable.by_page() - pager.next() - token = pager.continuation_token - - self.assertIsNotNone(token) - self.assertIsNotNone(_decode_token(token)) - def test_full_pk_legacy_replay_resumes_same_page(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) @@ -569,306 +516,138 @@ def test_full_pk_legacy_replay_resumes_same_page(self): replay_second_page = list(replay_pager.next())[0] self.assertEqual(second_page['id'], replay_second_page['id']) - def test_full_pk_structured_replay_resumes_same_page(self): - with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - query_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - pager = query_iterable.by_page() - pager.next() - token = pager.continuation_token - second_page = list(pager.next())[0] - - self.assertIsNotNone(token) - self.assertIsNotNone(_decode_token(token)) - - replay_pager = query_iterable.by_page(token) - replay_second_page = list(replay_pager.next())[0] - self.assertEqual(second_page['id'], replay_second_page['id']) - - def test_full_pk_structured_replay_rejects_query_mismatch(self): - with self._new_client_with_structured_full_pk_env("on") as (_, created_collection): - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - source_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - source_pager = source_iterable.by_page() - source_pager.next() - token = source_pager.continuation_token - self.assertIsNotNone(_decode_token(token)) - - mismatched_query_iterable = created_collection.query_items( - query='SELECT VALUE c.id from c', - partition_key='pk', - max_item_count=1, - ) - with self.assertRaisesRegex(ValueError, 'query hash mismatch'): - mismatched_query_iterable.by_page(token).next() - - def test_full_pk_structured_replay_rejects_partition_key_mismatch(self): - with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk2', 'id': str(uuid.uuid4())}) - source_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - source_pager = source_iterable.by_page() - source_pager.next() - token = source_pager.continuation_token - self.assertIsNotNone(_decode_token(token)) - - mismatched_pk_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk2', - max_item_count=1, - ) - with self.assertRaisesRegex(ValueError, 'feed_range hash mismatch'): - mismatched_pk_iterable.by_page(token).next() - - def test_mixed_version_structured_token_replayed_by_legacy_mode(self): - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - new_mode_pager = new_mode_iterable.by_page() - new_mode_pager.next() - structured_token = new_mode_pager.continuation_token - self.assertIsNotNone(_decode_token(structured_token)) - - legacy_mode_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_mode_pager = legacy_mode_iterable.by_page(structured_token) - list(legacy_mode_pager.next()) - resumed_continuation = legacy_mode_pager.continuation_token - self.assertIsNotNone(resumed_continuation) - self.assertIsNone(_decode_token(resumed_continuation)) - - def test_mixed_version_legacy_token_replayed_by_structured_mode(self): + def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error(self): + """A non-HTTP error during the wrapped legacy-bridge POST must + propagate as-is. The 400-only fallback must NOT swallow runtime + exceptions or retry the page on them. + """ created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - legacy_mode_iterable = created_collection.query_items( + legacy_iterable = created_collection.query_items( query='SELECT * from c', partition_key='pk', max_item_count=1, ) - legacy_mode_pager = legacy_mode_iterable.by_page() - legacy_mode_pager.next() - legacy_token = legacy_mode_pager.continuation_token + legacy_pager = legacy_iterable.by_page() + legacy_pager.next() + legacy_token = legacy_pager.continuation_token + # By the policy under review, full-PK always emits legacy. self.assertIsNone(_decode_token(legacy_token)) - with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - new_mode_pager = new_mode_iterable.by_page(legacy_token) - list(new_mode_pager.next()) - resumed_continuation = new_mode_pager.continuation_token - self.assertIsNotNone(resumed_continuation) - self.assertIsNotNone(_decode_token(resumed_continuation)) - - def test_full_pk_split_during_page_resets_retry_state(self): - pk_value = 'pk-' + str(uuid.uuid4()) - inserted_ids = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())] - with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): - for doc_id in inserted_ids: - created_collection.upsert_item(body={'pk': pk_value, 'id': doc_id}) - query_iterable = created_collection.query_items( - query='SELECT * from c ORDER BY c.id', - partition_key=pk_value, - max_item_count=1, - ) - pager = query_iterable.by_page() - pager.next() - continuation_token = pager.continuation_token - pager.next() - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - injected_split = False - - def _split_once_post(*args, **kwargs): - nonlocal injected_split - req_headers = args[3] - if ( - not injected_split - and req_headers.get(http_constants.HttpHeaders.Continuation) - ): - injected_split = True - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.GONE, - sub_status=http_constants.SubStatusCodes.PARTITION_KEY_RANGE_GONE, - message='simulated split during full-pk page fetch', - ) - return original_post(*args, **kwargs) - - client_conn._CosmosClientConnection__Post = _split_once_post - try: - replay_pager = query_iterable.by_page(continuation_token) - replay_second_page = list(replay_pager.next())[0] - self.assertTrue(injected_split) - self.assertIn(replay_second_page['id'], inserted_ids) - self.assertIsNotNone(_decode_token(replay_pager.continuation_token)) - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post - def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error(self): - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + def _failing_post(*args, **kwargs): + raise RuntimeError("bridge-runtime-error") - legacy_mode_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_mode_pager = legacy_mode_iterable.by_page() - legacy_mode_pager.next() - legacy_token = legacy_mode_pager.continuation_token - self.assertIsNone(_decode_token(legacy_token)) - - with self._new_client_with_structured_full_pk_env("on") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - client_conn = structured_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - post_call_count = 0 - - def _failing_post(*args, **kwargs): - nonlocal post_call_count - post_call_count += 1 - if post_call_count == 1: - raise RuntimeError("bridge-runtime-error") - return original_post(*args, **kwargs) - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with self.assertRaisesRegex(RuntimeError, 'bridge-runtime-error'): - new_mode_iterable.by_page(legacy_token).next() - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn._CosmosClientConnection__Post = _failing_post + try: + with self.assertRaisesRegex(RuntimeError, 'bridge-runtime-error'): + legacy_iterable.by_page(legacy_token).next() + finally: + client_conn._CosmosClientConnection__Post = original_post def test_full_pk_legacy_bridge_does_not_fallback_on_unrelated_service_error(self): + """Only 400 BadRequest triggers the wrapped-path fallback. Other + CosmosHttpResponseError statuses (e.g. 429) must propagate + without an old-shape retry being issued. + """ created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - legacy_mode_iterable = created_collection.query_items( + legacy_iterable = created_collection.query_items( query='SELECT * from c', partition_key='pk', max_item_count=1, ) - legacy_mode_pager = legacy_mode_iterable.by_page() - legacy_mode_pager.next() - legacy_token = legacy_mode_pager.continuation_token + legacy_pager = legacy_iterable.by_page() + legacy_pager.next() + legacy_token = legacy_pager.continuation_token self.assertIsNone(_decode_token(legacy_token)) - with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + saw_legacy_fallback_headers = False + + def _failing_post(*args, **kwargs): + nonlocal saw_legacy_fallback_headers + req_headers = args[3] + # The fallback would re-issue with raw PartitionKey + raw + # legacy continuation + no PartitionKeyRangeID. If we ever + # see that shape on a 429-only path, the gating is wrong. + if ( + http_constants.HttpHeaders.PartitionKey in req_headers + and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token + and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers + ): + saw_legacy_fallback_headers = True + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, + message="throttled", ) - client_conn = structured_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - saw_legacy_fallback_headers = False - - def _failing_post(*args, **kwargs): - nonlocal saw_legacy_fallback_headers - req_headers = args[3] - if ( - http_constants.HttpHeaders.PartitionKey in req_headers - and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token - and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers - ): - saw_legacy_fallback_headers = True - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, - message="throttled", - ) - client_conn._CosmosClientConnection__Post = _failing_post - try: - with self.assertRaises(exceptions.CosmosHttpResponseError): - new_mode_iterable.by_page(legacy_token).next() - self.assertFalse(saw_legacy_fallback_headers) - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn._CosmosClientConnection__Post = _failing_post + try: + with self.assertRaises(exceptions.CosmosHttpResponseError): + legacy_iterable.by_page(legacy_token).next() + self.assertFalse(saw_legacy_fallback_headers) + finally: + client_conn._CosmosClientConnection__Post = original_post def test_full_pk_legacy_bridge_fallback_failure_stamps_checkpoint(self): + """When the wrapped POST returns 400 and the one-shot legacy + fallback POST also fails, the original error is surfaced AND a + resumable checkpoint continuation must be stamped on + ``last_response_headers`` so the caller can retry. + """ created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - legacy_mode_iterable = created_collection.query_items( + legacy_iterable = created_collection.query_items( query='SELECT * from c', partition_key='pk', max_item_count=1, ) - legacy_mode_pager = legacy_mode_iterable.by_page() - legacy_mode_pager.next() - legacy_token = legacy_mode_pager.continuation_token + legacy_pager = legacy_iterable.by_page() + legacy_pager.next() + legacy_token = legacy_pager.continuation_token self.assertIsNone(_decode_token(legacy_token)) - with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - client_conn = structured_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - post_call_count = 0 - - def _failing_post(*args, **kwargs): - nonlocal post_call_count - post_call_count += 1 - if post_call_count == 1: - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.BAD_REQUEST, - message="legacy bridge compatibility failure", - ) - raise RuntimeError("fallback-post-failed") - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with self.assertRaisesRegex(RuntimeError, 'fallback-post-failed'): - new_mode_iterable.by_page(legacy_token).next() - self.assertEqual(post_call_count, 2) - continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) - self.assertIsNotNone(continuation) - self.assertIsNotNone(_decode_token(continuation)) - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + post_call_count = 0 + + def _failing_post(*args, **kwargs): + nonlocal post_call_count + post_call_count += 1 + if post_call_count == 1: + # First wrapped POST: 400 -> triggers one-shot fallback. + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.BAD_REQUEST, + message="legacy bridge compatibility failure", + ) + # Fallback POST: also fails, original error should surface. + raise RuntimeError("fallback-post-failed") + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with self.assertRaisesRegex(RuntimeError, 'fallback-post-failed'): + legacy_iterable.by_page(legacy_token).next() + self.assertEqual(post_call_count, 2) + # Checkpoint stamped: full-PK emits legacy, so the + # checkpoint continuation is the original legacy token. + continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + self.assertIsNotNone(continuation) + self.assertIsNone(_decode_token(continuation)) + finally: + client_conn._CosmosClientConnection__Post = original_post def test_cross_partition_query_with_none_partition_key(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index aad19680f546..0fa3552442d0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -4,9 +4,7 @@ import os import unittest import uuid -from contextlib import asynccontextmanager from asyncio import gather -from unittest.mock import patch import pytest @@ -56,17 +54,6 @@ async def asyncSetUp(self): async def asyncTearDown(self): await self.client.close() - @asynccontextmanager - async def _new_client_with_structured_full_pk_env(self, value: str): - with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": value}, clear=False): - async with CosmosClient( - self.host, - self.masterKey, - multiple_write_locations=self.use_multiple_write_locations, - ) as client: - database = client.get_database_client(self.TEST_DATABASE_ID) - created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - yield client, created_collection async def test_first_and_last_slashes_trimmed_for_query_string_async(self): created_collection = await self.created_db.create_container( @@ -504,46 +491,6 @@ async def test_full_pk_continuation_emits_legacy_by_default_async(self): assert token is not None assert _decode_token(token) is None - async def test_full_pk_continuation_emits_structured_with_env_var_async(self): - """Enabling the environment variable returns structured continuation tokens.""" - async with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - query_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - pager = query_iterable.by_page() - await pager.__anext__() - token = pager.continuation_token - - assert token is not None - assert _decode_token(token) is not None - - async def test_full_pk_continuation_emits_structured_with_env_var_and_new_client_async(self): - """The environment variable is read when the client is created.""" - with patch.dict(os.environ, {"AZURE_COSMOS_EMIT_STRUCTURED_CONTINUATION_PK": "true"}, clear=False): - async with CosmosClient( - self.host, - self.masterKey, - ) as client: - database = client.get_database_client(self.TEST_DATABASE_ID) - created_collection = database.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - query_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - pager = query_iterable.by_page() - await pager.__anext__() - token = pager.continuation_token - - assert token is not None - assert _decode_token(token) is not None async def test_full_pk_legacy_replay_resumes_same_page_async(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) @@ -567,305 +514,132 @@ async def test_full_pk_legacy_replay_resumes_same_page_async(self): replay_second_page = [item async for item in await replay_pager.__anext__()][0] assert second_page['id'] == replay_second_page['id'] - async def test_full_pk_structured_replay_resumes_same_page_async(self): - async with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - query_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - pager = query_iterable.by_page() - await pager.__anext__() - token = pager.continuation_token - second_page = [item async for item in await pager.__anext__()][0] - replay_pager = query_iterable.by_page(token) - replay_second_page = [item async for item in await replay_pager.__anext__()][0] - assert token is not None - assert _decode_token(token) is not None - assert second_page['id'] == replay_second_page['id'] - - async def test_full_pk_structured_replay_rejects_query_mismatch_async(self): - async with self._new_client_with_structured_full_pk_env("on") as (_, created_collection): - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - source_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - source_pager = source_iterable.by_page() - await source_pager.__anext__() - token = source_pager.continuation_token - assert _decode_token(token) is not None - - mismatched_query_iterable = created_collection.query_items( - query='SELECT VALUE c.id from c', - partition_key='pk', - max_item_count=1, - ) - with pytest.raises(ValueError, match='query hash mismatch'): - await mismatched_query_iterable.by_page(token).__anext__() - - async def test_full_pk_structured_replay_rejects_partition_key_mismatch_async(self): - async with self._new_client_with_structured_full_pk_env("1") as (_, created_collection): - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk2', 'id': str(uuid.uuid4())}) - source_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - source_pager = source_iterable.by_page() - await source_pager.__anext__() - token = source_pager.continuation_token - assert _decode_token(token) is not None - - mismatched_pk_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk2', - max_item_count=1, - ) - with pytest.raises(ValueError, match='feed_range hash mismatch'): - await mismatched_pk_iterable.by_page(token).__anext__() - - async def test_mixed_version_structured_token_replayed_by_legacy_mode_async(self): - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - async with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - new_mode_pager = new_mode_iterable.by_page() - await new_mode_pager.__anext__() - structured_token = new_mode_pager.continuation_token - assert _decode_token(structured_token) is not None - legacy_mode_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_mode_pager = legacy_mode_iterable.by_page(structured_token) - await legacy_mode_pager.__anext__() - resumed_continuation = legacy_mode_pager.continuation_token - assert resumed_continuation is not None - assert _decode_token(resumed_continuation) is None - - async def test_mixed_version_legacy_token_replayed_by_structured_mode_async(self): + async def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error_async(self): + """A non-HTTP error during the wrapped legacy-bridge POST must + propagate as-is on the async path. The 400-only fallback must + NOT swallow runtime exceptions or retry the page on them. + """ created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - legacy_mode_iterable = created_collection.query_items( + legacy_iterable = created_collection.query_items( query='SELECT * from c', partition_key='pk', max_item_count=1, ) - legacy_mode_pager = legacy_mode_iterable.by_page() - await legacy_mode_pager.__anext__() - legacy_token = legacy_mode_pager.continuation_token + legacy_pager = legacy_iterable.by_page() + await legacy_pager.__anext__() + legacy_token = legacy_pager.continuation_token assert _decode_token(legacy_token) is None - async with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - new_mode_pager = new_mode_iterable.by_page(legacy_token) - await new_mode_pager.__anext__() - resumed_continuation = new_mode_pager.continuation_token - assert resumed_continuation is not None - assert _decode_token(resumed_continuation) is not None - - async def test_full_pk_split_during_page_resets_retry_state_async(self): - pk_value = 'pk-' + str(uuid.uuid4()) - inserted_ids = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())] - - async with self._new_client_with_structured_full_pk_env("yes") as (_, created_collection): - for doc_id in inserted_ids: - await created_collection.upsert_item(body={'pk': pk_value, 'id': doc_id}) - query_iterable = created_collection.query_items( - query='SELECT * from c ORDER BY c.id', - partition_key=pk_value, - max_item_count=1, - ) - pager = query_iterable.by_page() - await pager.__anext__() - continuation_token = pager.continuation_token - await pager.__anext__() - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - injected_split = False - - async def _split_once_post(*args, **kwargs): - nonlocal injected_split - req_headers = args[3] - if ( - not injected_split - and req_headers.get(http_constants.HttpHeaders.Continuation) - ): - injected_split = True - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.GONE, - sub_status=http_constants.SubStatusCodes.PARTITION_KEY_RANGE_GONE, - message='simulated split during full-pk page fetch async', - ) - return await original_post(*args, **kwargs) - - client_conn._CosmosClientConnection__Post = _split_once_post - try: - replay_pager = query_iterable.by_page(continuation_token) - replay_second_page = [item async for item in await replay_pager.__anext__()][0] - assert injected_split - assert replay_second_page['id'] in inserted_ids - assert _decode_token(replay_pager.continuation_token) is not None - finally: - client_conn._CosmosClientConnection__Post = original_post - - async def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error_async(self): - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post - legacy_mode_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_mode_pager = legacy_mode_iterable.by_page() - await legacy_mode_pager.__anext__() - legacy_token = legacy_mode_pager.continuation_token - assert _decode_token(legacy_token) is None + async def _failing_post(*args, **kwargs): + raise RuntimeError("bridge-runtime-error-async") - async with self._new_client_with_structured_full_pk_env("on") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - client_conn = structured_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - post_call_count = 0 - - async def _failing_post(*args, **kwargs): - nonlocal post_call_count - post_call_count += 1 - if post_call_count == 1: - raise RuntimeError("bridge-runtime-error-async") - return await original_post(*args, **kwargs) - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with pytest.raises(RuntimeError, match='bridge-runtime-error-async'): - await new_mode_iterable.by_page(legacy_token).__anext__() - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn._CosmosClientConnection__Post = _failing_post + try: + with pytest.raises(RuntimeError, match='bridge-runtime-error-async'): + await legacy_iterable.by_page(legacy_token).__anext__() + finally: + client_conn._CosmosClientConnection__Post = original_post async def test_full_pk_legacy_bridge_does_not_fallback_on_unrelated_service_error_async(self): + """Only 400 BadRequest triggers the wrapped-path fallback. Other + statuses (e.g. 429) must propagate without an old-shape retry + being issued, on the async path. + """ created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - legacy_mode_iterable = created_collection.query_items( + legacy_iterable = created_collection.query_items( query='SELECT * from c', partition_key='pk', max_item_count=1, ) - legacy_mode_pager = legacy_mode_iterable.by_page() - await legacy_mode_pager.__anext__() - legacy_token = legacy_mode_pager.continuation_token + legacy_pager = legacy_iterable.by_page() + await legacy_pager.__anext__() + legacy_token = legacy_pager.continuation_token assert _decode_token(legacy_token) is None - async with self._new_client_with_structured_full_pk_env("true") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + saw_legacy_fallback_headers = False + + async def _failing_post(*args, **kwargs): + nonlocal saw_legacy_fallback_headers + req_headers = args[3] + if ( + http_constants.HttpHeaders.PartitionKey in req_headers + and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token + and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers + ): + saw_legacy_fallback_headers = True + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, + message="throttled", ) - client_conn = structured_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - saw_legacy_fallback_headers = False - - async def _failing_post(*args, **kwargs): - nonlocal saw_legacy_fallback_headers - req_headers = args[3] - if ( - http_constants.HttpHeaders.PartitionKey in req_headers - and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token - and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers - ): - saw_legacy_fallback_headers = True - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, - message="throttled", - ) - client_conn._CosmosClientConnection__Post = _failing_post - try: - with pytest.raises(exceptions.CosmosHttpResponseError): - await new_mode_iterable.by_page(legacy_token).__anext__() - assert not saw_legacy_fallback_headers - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn._CosmosClientConnection__Post = _failing_post + try: + with pytest.raises(exceptions.CosmosHttpResponseError): + await legacy_iterable.by_page(legacy_token).__anext__() + assert not saw_legacy_fallback_headers + finally: + client_conn._CosmosClientConnection__Post = original_post async def test_full_pk_legacy_bridge_fallback_failure_stamps_checkpoint_async(self): + """When the wrapped POST returns 400 and the one-shot legacy + fallback POST also fails, the original error is surfaced AND a + resumable checkpoint continuation must be stamped on + ``last_response_headers`` (async path). + """ created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - legacy_mode_iterable = created_collection.query_items( + legacy_iterable = created_collection.query_items( query='SELECT * from c', partition_key='pk', max_item_count=1, ) - legacy_mode_pager = legacy_mode_iterable.by_page() - await legacy_mode_pager.__anext__() - legacy_token = legacy_mode_pager.continuation_token + legacy_pager = legacy_iterable.by_page() + await legacy_pager.__anext__() + legacy_token = legacy_pager.continuation_token assert _decode_token(legacy_token) is None - async with self._new_client_with_structured_full_pk_env("1") as (_, structured_collection): - new_mode_iterable = structured_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - client_conn = structured_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - post_call_count = 0 - - async def _failing_post(*args, **kwargs): - nonlocal post_call_count - post_call_count += 1 - if post_call_count == 1: - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.BAD_REQUEST, - message="legacy bridge compatibility failure", - ) - raise RuntimeError("fallback-post-failed-async") - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with pytest.raises(RuntimeError, match='fallback-post-failed-async'): - await new_mode_iterable.by_page(legacy_token).__anext__() - assert post_call_count == 2 - continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) - assert continuation is not None - assert _decode_token(continuation) is not None - finally: - client_conn._CosmosClientConnection__Post = original_post + client_conn = created_collection.client_connection + original_post = client_conn._CosmosClientConnection__Post + post_call_count = 0 + + async def _failing_post(*args, **kwargs): + nonlocal post_call_count + post_call_count += 1 + if post_call_count == 1: + raise exceptions.CosmosHttpResponseError( + status_code=http_constants.StatusCodes.BAD_REQUEST, + message="legacy bridge compatibility failure", + ) + raise RuntimeError("fallback-post-failed-async") + + client_conn._CosmosClientConnection__Post = _failing_post + try: + with pytest.raises(RuntimeError, match='fallback-post-failed-async'): + await legacy_iterable.by_page(legacy_token).__anext__() + assert post_call_count == 2 + continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) + assert continuation is not None + assert _decode_token(continuation) is None + finally: + client_conn._CosmosClientConnection__Post = original_post + async def test_cross_partition_query_with_none_partition_key_async(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) From 8426ee419616c441cae1f4cb98736db236d4fdf6 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 11 May 2026 13:08:33 -0500 Subject: [PATCH 08/18] removing obsolete tests --- .../azure/cosmos/_cosmos_client_connection.py | 244 +++++++++-------- .../_routing/feed_range_continuation.py | 59 ++-- .../aio/_cosmos_client_connection_async.py | 256 +++++++++--------- .../test_feed_range_continuation_token.py | 105 +++++++ .../tests/test_partition_split_retry_unit.py | 142 ++++++++++ .../test_partition_split_retry_unit_async.py | 107 ++++++++ sdk/cosmos/azure-cosmos/tests/test_query.py | 132 --------- .../azure-cosmos/tests/test_query_async.py | 125 --------- 8 files changed, 655 insertions(+), 515 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 24917acdf74d..790194e72ba7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -79,11 +79,9 @@ _count_page_items_from_partial_result, _decode_token, _derive_initial_feedranges, - _hash_feed_range, - _hash_query_spec, _increment_explode_iterations_or_raise, _normalize_max_item_count, - _should_attempt_legacy_bridge_fallback, + _should_bridge_legacy_continuation, _update_no_progress_page_count, _validate_token_identity, _write_query_outbound_continuation, @@ -3317,7 +3315,6 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: feed_range_epk = None container_properties = kwargs.pop("container_properties", None) is_full_pk_structured_scope = False - legacy_partition_key_header = req_headers.get(http_constants.HttpHeaders.PartitionKey) if "feed_range" in kwargs: feed_range = kwargs.pop("feed_range") feed_range_epk = FeedRangeInternalEpk.from_json(feed_range).get_normalized_range() @@ -3353,18 +3350,60 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # ``None`` means start from the beginning of the requested # feed range. page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) - query_hash = _hash_query_spec(query) - feedrange_hash = _hash_feed_range(feed_range_epk) + # Identity hashes are only needed when an inbound v=1 token must be + # validated, or when the outbound continuation will be a v=1 envelope. + # Compute lazily so the single-partition legacy outbound path pays no + # hashing cost. The token writers handle ``None`` by computing on demand. + query_hash: Optional[str] = None + feedrange_hash: Optional[str] = None + # Single shared copy of options for routing-map lookups in this call. + # ``get_overlapping_ranges`` does not mutate options; copying once + # avoids per-iteration ``dict(options)`` allocations. + routing_options = dict(options) inbound_serialized_continuation = options.get("continuation") inbound_token_payload = _decode_token(inbound_serialized_continuation) legacy_bridge_in_use = False - legacy_fallback_attempted = False - if inbound_serialized_continuation and inbound_token_payload is None: - if is_full_pk_structured_scope: - _LOGGER.warning( - "Full-PK query continuation token is in legacy format; " - "bridging it into structured pagination state for resume." + # Cache for the input scope's single-partition classification. + # We compute it at most once per __QueryFeed call so inbound + # bridge-detection, mid-page checkpoint, and end-of-page outbound + # writer all agree even if the PKRANGE cache refreshes mid-call. + cached_is_single_partition: Optional[bool] = None + + def _is_input_scope_single_partition() -> bool: + """Return True when the caller input range currently maps to one physical partition. + + Result is cached for the duration of this __QueryFeed call. + """ + nonlocal cached_is_single_partition + if cached_is_single_partition is None: + scope_overlaps = self._routing_map_provider.get_overlapping_ranges( + resource_id, [feed_range_epk], routing_options + ) + cached_is_single_partition = ( + len(_derive_initial_feedranges(feed_range_epk, scope_overlaps)) == 1 ) + return cached_is_single_partition + + if inbound_serialized_continuation and inbound_token_payload is None: + scope_is_single_partition = False + if not is_full_pk_structured_scope: + scope_is_single_partition = _is_input_scope_single_partition() + if _should_bridge_legacy_continuation( + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + scope_is_single_partition, + ): + if is_full_pk_structured_scope: + _LOGGER.warning( + "Full-PK query continuation token is in legacy format; " + "bridging it into structured pagination state for resume." + ) + else: + _LOGGER.warning( + "Feed-range query continuation token is in legacy format and the input scope " + "currently maps to one physical partition; honoring legacy continuation for resume." + ) legacy_bridge_in_use = True else: _LOGGER.warning( @@ -3395,7 +3434,7 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # each overlap into a feedrange (intersection of that partition # and the input feed_range). first_overlaps = self._routing_map_provider.get_overlapping_ranges( - resource_id, [feed_range_epk], dict(options) + resource_id, [feed_range_epk], routing_options ) all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps) if not all_feedranges: @@ -3429,6 +3468,7 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, + (not is_full_pk_structured_scope) and _is_input_scope_single_partition(), query_hash, feedrange_hash, ) @@ -3446,120 +3486,91 @@ def _checkpoint_and_reraise(error: Exception) -> None: if head_feedrange is None: break - # Look up the live routing map for the current feedrange. - # Doing this every iteration is what makes the token - # split-safe. - overlapping = self._routing_map_provider.get_overlapping_ranges( - resource_id, [head_feedrange], dict(options) - ) - overlapping, partition_scope = _build_scope_from_overlaps( - overlapping, head_feedrange - ) - - # If routing returns multiple overlaps, the head sub-range now spans a split - # that occurred after the token was created. Re-slice and re-resolve until - # each head maps to one partition. See - # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. - explode_iterations = 0 - while pagination_state.explode_on_multi_overlap(overlapping): - explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) - head_feedrange = pagination_state.head_range - if head_feedrange is None: - break + # Wrap all mid-page work that can raise (routing lookups, + # scope build/explode, backend POST, and result merge) so + # we always stamp a resumable checkpoint into + # last_response_headers[Continuation] before re-raising. + # Post-result accounting below is pure local bookkeeping + # and is intentionally left outside this try. + try: + # Look up the live routing map for the current feedrange. + # Doing this every iteration is what makes the token + # split-safe. overlapping = self._routing_map_provider.get_overlapping_ranges( - resource_id, [head_feedrange], dict(options) + resource_id, [head_feedrange], routing_options ) overlapping, partition_scope = _build_scope_from_overlaps( overlapping, head_feedrange ) - head_feedrange = pagination_state.head_range - if head_feedrange is None: - continue - - # Populate request headers for this single backend POST. - # The shared helper handles partition routing (PKR id + - # optional EPK filter), page-size cap, and continuation - # set/clear so the same rules apply to sync and async. - _apply_feedrange_request_headers( - req_headers, - overlapping, - partition_scope, - head_feedrange, - pagination_state.page_size_hint, - pagination_state.head_bc, - ) - # Use the session token for this specific partition so we don't - # send a compound token covering all partitions. - base.set_session_token_header( - self, req_headers, path, request_params, options, overlapping[0]["id"] - ) + # If routing returns multiple overlaps, the head sub-range now spans a split + # that occurred after the token was created. Re-slice and re-resolve until + # each head maps to one partition. See + # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. + explode_iterations = 0 + while pagination_state.explode_on_multi_overlap(overlapping): + explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + overlapping = self._routing_map_provider.get_overlapping_ranges( + resource_id, [head_feedrange], routing_options + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + head_feedrange = pagination_state.head_range + if head_feedrange is None: + continue + + # Populate request headers for this single backend POST. + # The shared helper handles partition routing (PKR id + + # optional EPK filter), page-size cap, and continuation + # set/clear so the same rules apply to sync and async. + _apply_feedrange_request_headers( + req_headers, + overlapping, + partition_scope, + head_feedrange, + pagination_state.page_size_hint, + pagination_state.head_bc, + ) + # Use the session token for this specific partition so we don't + # send a compound token covering all partitions. + base.set_session_token_header( + self, req_headers, path, request_params, options, overlapping[0]["id"] + ) - try: backend_query_result, backend_response_headers = self.__Post( path, request_params, query, req_headers, **kwargs ) - except exceptions.CosmosHttpResponseError as post_error: - if ( - legacy_bridge_in_use - and not legacy_fallback_attempted - and _should_attempt_legacy_bridge_fallback(post_error) - ): - legacy_fallback_attempted = True - req_headers.pop(http_constants.HttpHeaders.PartitionKeyRangeID, None) - req_headers.pop(http_constants.HttpHeaders.StartEpkString, None) - req_headers.pop(http_constants.HttpHeaders.EndEpkString, None) - req_headers.pop(http_constants.HttpHeaders.ReadFeedKeyType, None) - if legacy_partition_key_header is not None: - req_headers[http_constants.HttpHeaders.PartitionKey] = legacy_partition_key_header - req_headers[http_constants.HttpHeaders.Continuation] = inbound_serialized_continuation - base.set_session_token_header( - self, req_headers, path, request_params, options, partition_key_range_id + feedrange_response_headers = backend_response_headers + self.last_response_headers = feedrange_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) + + # Merge results, falling back to a plain extend if the + # aggregating merge raises (it can on aggregated queries + # during splits). + try: + results = base._merge_query_results(results, backend_query_result, query) + except ValueError as merge_error: + base._raise_query_merge_value_error(merge_error) + except (TypeError, KeyError) as merge_error: + _LOGGER.warning( + "Falling back to non-aggregate merge after aggregate merge failure: %s", + merge_error, ) - try: - backend_query_result, backend_response_headers = self.__Post( - path, request_params, query, req_headers, **kwargs - ) - except Exception as fallback_error: # pylint: disable=broad-exception-caught - _checkpoint_and_reraise(fallback_error) - self.last_response_headers = backend_response_headers - if internal_headers_capture is not None: - _capture_internal_headers(backend_response_headers) - self._UpdateSessionIfRequired( - req_headers, backend_query_result, backend_response_headers - ) - if response_headers_list is not None: - response_headers_list.append(backend_response_headers.copy()) - if response_hook: - response_hook(backend_response_headers, backend_query_result) - return __GetBodiesFromQueryResult(backend_query_result), backend_response_headers - _checkpoint_and_reraise(post_error) - except Exception as post_error: # pylint: disable=broad-exception-caught - _checkpoint_and_reraise(post_error) - feedrange_response_headers = backend_response_headers - self.last_response_headers = feedrange_response_headers - if internal_headers_capture is not None: - _capture_internal_headers(backend_response_headers) - self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) - - # Merge results, falling back to a plain extend if the - # aggregating merge raises (it can on aggregated queries - # during splits). - try: - results = base._merge_query_results(results, backend_query_result, query) - except ValueError as merge_error: - base._raise_query_merge_value_error(merge_error) - except (TypeError, KeyError) as merge_error: - _LOGGER.warning( - "Falling back to non-aggregate merge after aggregate merge failure: %s", - merge_error, - ) - results_docs = results.get("Documents") if results else None - partial_docs = backend_query_result.get("Documents") if backend_query_result else None - if isinstance(results_docs, list) and isinstance(partial_docs, list): - results_docs.extend(partial_docs) - elif backend_query_result: - results = backend_query_result + results_docs = results.get("Documents") if results else None + partial_docs = backend_query_result.get("Documents") if backend_query_result else None + if isinstance(results_docs, list) and isinstance(partial_docs, list): + results_docs.extend(partial_docs) + elif backend_query_result: + results = backend_query_result + except Exception as mid_page_error: # pylint: disable=broad-exception-caught + _checkpoint_and_reraise(mid_page_error) previous_feedrange = pagination_state.head_range previous_backend_continuation = pagination_state.head_bc @@ -3612,6 +3623,7 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, + (not is_full_pk_structured_scope) and _is_input_scope_single_partition(), query_hash, feedrange_hash, ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py index ada117b41d68..a778d2cdd58c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py @@ -306,6 +306,28 @@ def _validate_token_identity( ) +def _should_bridge_legacy_continuation( + inbound_serialized_continuation: Optional[str], + inbound_token_payload: Optional[dict], + is_full_pk_structured_scope: bool, + is_single_partition_scope: bool, +) -> bool: + """Whether to bridge an inbound legacy continuation into pagination state. + + We bridge only when the inbound continuation exists, did not decode as + structured ``v=1`` (legacy/opaque token), and the current request scope can + be represented safely by a single legacy continuation slot: + + * full-PK structured scope (always single partition), or + * non-full-PK scope that currently maps to one physical partition. + """ + return bool( + inbound_serialized_continuation + and inbound_token_payload is None + and (is_full_pk_structured_scope or is_single_partition_scope) + ) + + def _extract_resume_queue( inbound: dict, ) -> List[Tuple[routing_range.Range, Optional[str]]]: @@ -611,14 +633,17 @@ def _write_query_outbound_continuation( query: Any, feed_range_epk: routing_range.Range, is_full_pk_structured_scope: bool, - query_hash: str, - feedrange_hash: str, + emit_legacy_for_single_partition: bool, + query_hash: Optional[str], + feedrange_hash: Optional[str], ) -> None: """Write outbound continuation for feed-range pagination. Full-PK queries always emit the legacy single-string continuation so persisted bookmarks remain readable by older SDK versions. - Feed-range and prefix queries always emit the structured envelope. + Feed-range/prefix queries emit legacy continuation when the caller's + input scope currently maps to a single physical partition; otherwise + they emit the structured envelope. :param last_response_headers: Response headers to mutate. :type last_response_headers: MutableMapping[str, Any] @@ -632,14 +657,19 @@ def _write_query_outbound_continuation( :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range :param is_full_pk_structured_scope: Whether request scope is full-PK on structured path. :type is_full_pk_structured_scope: bool - :param query_hash: Precomputed query hash for outbound token identity. - :type query_hash: str - :param feedrange_hash: Precomputed feed range hash for outbound token identity. - :type feedrange_hash: str + :param emit_legacy_for_single_partition: Whether non-full-PK scope currently maps to a + single physical partition and can safely emit legacy continuation. + :type emit_legacy_for_single_partition: bool + :param query_hash: Precomputed query hash for outbound token identity, or ``None`` to + compute lazily inside the structured-token writer when needed. + :type query_hash: Optional[str] + :param feedrange_hash: Precomputed feed range hash for outbound token identity, or + ``None`` to compute lazily inside the structured-token writer when needed. + :type feedrange_hash: Optional[str] :returns: None. Mutates ``last_response_headers`` in place. :rtype: None """ - if is_full_pk_structured_scope: + if is_full_pk_structured_scope or emit_legacy_for_single_partition: legacy_outbound = pagination_state.head_bc if legacy_outbound is None: last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) @@ -656,19 +686,6 @@ def _write_query_outbound_continuation( ) -def _should_attempt_legacy_bridge_fallback(error: Any) -> bool: - """Return whether a compatibility fallback should be attempted. - - Compatibility fallback is restricted to legacy-token bridge failures - that surface as ``400 BadRequest``. - - :param error: Exception raised by backend request execution. - :type error: Any - :returns: ``True`` when the error is a ``400 BadRequest`` compatibility failure. - :rtype: bool - """ - return getattr(error, "status_code", None) == http_constants.StatusCodes.BAD_REQUEST - def _build_outbound_token( resource_id: str, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index dc647335b382..7f12ddbc3d7a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -63,11 +63,9 @@ _count_page_items_from_partial_result, _decode_token, _derive_initial_feedranges, - _hash_feed_range, - _hash_query_spec, _increment_explode_iterations_or_raise, _normalize_max_item_count, - _should_attempt_legacy_bridge_fallback, + _should_bridge_legacy_continuation, _update_no_progress_page_count, _validate_token_identity, _write_query_outbound_continuation, @@ -3118,7 +3116,6 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # Check if the overlapping ranges can be populated feed_range_epk = None is_full_pk_structured_scope = False - legacy_partition_key_header = req_headers.get(http_constants.HttpHeaders.PartitionKey) if "feed_range" in kwargs: feed_range = kwargs.pop("feed_range") feed_range_epk = FeedRangeInternalEpk.from_json(feed_range).get_normalized_range() @@ -3150,18 +3147,60 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # ``None`` means start from the beginning of the requested # feed range. page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) - query_hash = _hash_query_spec(query) - feedrange_hash = _hash_feed_range(feed_range_epk) + # Identity hashes are only needed when an inbound v=1 token must be + # validated, or when the outbound continuation will be a v=1 envelope. + # Compute lazily so the single-partition legacy outbound path pays no + # hashing cost. The token writers handle ``None`` by computing on demand. + query_hash: Optional[str] = None + feedrange_hash: Optional[str] = None + # Single shared copy of options for routing-map lookups in this call. + # ``get_overlapping_ranges`` does not mutate options; copying once + # avoids per-iteration ``dict(options)`` allocations. + routing_options = dict(options) inbound_serialized_continuation = options.get("continuation") inbound_token_payload = _decode_token(inbound_serialized_continuation) legacy_bridge_in_use = False - legacy_fallback_attempted = False - if inbound_serialized_continuation and inbound_token_payload is None: - if is_full_pk_structured_scope: - _LOGGER.warning( - "Full-PK query continuation token is in legacy format; " - "bridging it into structured pagination state for resume." + # Cache for the input scope's single-partition classification. + # We compute it at most once per __QueryFeed call so inbound + # bridge-detection, mid-page checkpoint, and end-of-page outbound + # writer all agree even if the PKRANGE cache refreshes mid-call. + cached_is_single_partition: Optional[bool] = None + + async def _is_input_scope_single_partition() -> bool: + """Return True when the caller input range currently maps to one physical partition. + + Result is cached for the duration of this __QueryFeed call. + """ + nonlocal cached_is_single_partition + if cached_is_single_partition is None: + scope_overlaps = await self._routing_map_provider.get_overlapping_ranges( + resource_id_str, [feed_range_epk], routing_options + ) + cached_is_single_partition = ( + len(_derive_initial_feedranges(feed_range_epk, scope_overlaps)) == 1 ) + return cached_is_single_partition + + if inbound_serialized_continuation and inbound_token_payload is None: + scope_is_single_partition = False + if not is_full_pk_structured_scope: + scope_is_single_partition = await _is_input_scope_single_partition() + if _should_bridge_legacy_continuation( + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + scope_is_single_partition, + ): + if is_full_pk_structured_scope: + _LOGGER.warning( + "Full-PK query continuation token is in legacy format; " + "bridging it into structured pagination state for resume." + ) + else: + _LOGGER.warning( + "Feed-range query continuation token is in legacy format and the input scope " + "currently maps to one physical partition; honoring legacy continuation for resume." + ) legacy_bridge_in_use = True else: _LOGGER.warning( @@ -3188,7 +3227,7 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: ) else: first_overlaps = await self._routing_map_provider.get_overlapping_ranges( - id_, [feed_range_epk], dict(options) + resource_id_str, [feed_range_epk], routing_options ) all_feedranges = _derive_initial_feedranges(feed_range_epk, first_overlaps) if not all_feedranges: @@ -3210,11 +3249,14 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: feedrange_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() consecutive_no_progress_pages = 0 - def _checkpoint_and_reraise(error: Exception) -> None: + async def _checkpoint_and_reraise(error: Exception) -> None: # Intentionally broad: stamp the latest resumable checkpoint # for any mid-page failure, then re-raise the original error. self.last_response_headers = feedrange_response_headers try: + single_partition_scope_for_outbound = ( + (not is_full_pk_structured_scope) and (await _is_input_scope_single_partition()) + ) _write_query_outbound_continuation( feedrange_response_headers, pagination_state, @@ -3222,6 +3264,7 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, + single_partition_scope_for_outbound, query_hash, feedrange_hash, ) @@ -3239,56 +3282,62 @@ def _checkpoint_and_reraise(error: Exception) -> None: if head_feedrange is None: break - # Look up the live routing map for the current feedrange. - # Doing this every iteration is what makes the token - # split-safe. - overlapping = await self._routing_map_provider.get_overlapping_ranges( - id_, [head_feedrange], dict(options) - ) - overlapping, partition_scope = _build_scope_from_overlaps( - overlapping, head_feedrange - ) - - # If routing returns multiple overlaps, the head sub-range now spans a split - # that occurred after the token was created. Re-slice and re-resolve until - # each head maps to one partition. See - # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. - explode_iterations = 0 - while pagination_state.explode_on_multi_overlap(overlapping): - explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) - head_feedrange = pagination_state.head_range - if head_feedrange is None: - break + # Wrap all mid-page work that can raise (routing lookups, + # scope build/explode, backend POST, and result merge) so + # we always stamp a resumable checkpoint into + # last_response_headers[Continuation] before re-raising. + # Post-result accounting below is pure local bookkeeping + # and is intentionally left outside this try. + try: + # Look up the live routing map for the current feedrange. + # Doing this every iteration is what makes the token + # split-safe. overlapping = await self._routing_map_provider.get_overlapping_ranges( - id_, [head_feedrange], dict(options) + resource_id_str, [head_feedrange], routing_options ) overlapping, partition_scope = _build_scope_from_overlaps( overlapping, head_feedrange ) - head_feedrange = pagination_state.head_range - if head_feedrange is None: - continue - - # Populate request headers for this single backend POST. - # The shared helper handles partition routing (PKR id + - # optional EPK filter), page-size cap, and continuation - # set/clear so the same rules apply to sync and async. - _apply_feedrange_request_headers( - req_headers, - overlapping, - partition_scope, - head_feedrange, - pagination_state.page_size_hint, - pagination_state.head_bc, - ) - # Use the session token for this specific partition so we don't - # send a compound token covering all partitions. - await base.set_session_token_header_async( - self, req_headers, path, request_params, options, overlapping[0]["id"] - ) + # If routing returns multiple overlaps, the head sub-range now spans a split + # that occurred after the token was created. Re-slice and re-resolve until + # each head maps to one partition. See + # ``_FeedRangePaginationState.explode_on_multi_overlap`` for details. + explode_iterations = 0 + while pagination_state.explode_on_multi_overlap(overlapping): + explode_iterations = _increment_explode_iterations_or_raise(explode_iterations) + head_feedrange = pagination_state.head_range + if head_feedrange is None: + break + overlapping = await self._routing_map_provider.get_overlapping_ranges( + resource_id_str, [head_feedrange], routing_options + ) + overlapping, partition_scope = _build_scope_from_overlaps( + overlapping, head_feedrange + ) + + head_feedrange = pagination_state.head_range + if head_feedrange is None: + continue + + # Populate request headers for this single backend POST. + # The shared helper handles partition routing (PKR id + + # optional EPK filter), page-size cap, and continuation + # set/clear so the same rules apply to sync and async. + _apply_feedrange_request_headers( + req_headers, + overlapping, + partition_scope, + head_feedrange, + pagination_state.page_size_hint, + pagination_state.head_bc, + ) + # Use the session token for this specific partition so we don't + # send a compound token covering all partitions. + await base.set_session_token_header_async( + self, req_headers, path, request_params, options, overlapping[0]["id"] + ) - try: backend_query_result, backend_response_headers = await self.__Post( path, request_params, @@ -3296,71 +3345,32 @@ def _checkpoint_and_reraise(error: Exception) -> None: req_headers, **kwargs ) - except exceptions.CosmosHttpResponseError as post_error: - if ( - legacy_bridge_in_use - and not legacy_fallback_attempted - and _should_attempt_legacy_bridge_fallback(post_error) - ): - legacy_fallback_attempted = True - req_headers.pop(http_constants.HttpHeaders.PartitionKeyRangeID, None) - req_headers.pop(http_constants.HttpHeaders.StartEpkString, None) - req_headers.pop(http_constants.HttpHeaders.EndEpkString, None) - req_headers.pop(http_constants.HttpHeaders.ReadFeedKeyType, None) - if legacy_partition_key_header is not None: - req_headers[http_constants.HttpHeaders.PartitionKey] = legacy_partition_key_header - req_headers[http_constants.HttpHeaders.Continuation] = inbound_serialized_continuation - await base.set_session_token_header_async( - self, req_headers, path, request_params, options, partition_key_range_id + feedrange_response_headers = backend_response_headers + self.last_response_headers = feedrange_response_headers + if internal_headers_capture is not None: + _capture_internal_headers(backend_response_headers) + self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) + + # Merge results, falling back to a plain extend if the + # aggregating merge raises (it can on aggregated queries + # during splits). + try: + results = base._merge_query_results(results, backend_query_result, query) + except ValueError as merge_error: + base._raise_query_merge_value_error(merge_error) + except (TypeError, KeyError) as merge_error: + _LOGGER.warning( + "Falling back to non-aggregate merge after aggregate merge failure: %s", + merge_error, ) - try: - backend_query_result, backend_response_headers = await self.__Post( - path, - request_params, - query, - req_headers, - **kwargs - ) - except Exception as fallback_error: # pylint: disable=broad-exception-caught - _checkpoint_and_reraise(fallback_error) - self.last_response_headers = backend_response_headers - if internal_headers_capture is not None: - _capture_internal_headers(backend_response_headers) - self._UpdateSessionIfRequired( - req_headers, backend_query_result, backend_response_headers - ) - if response_headers_list is not None: - response_headers_list.append(backend_response_headers.copy()) - if response_hook: - response_hook(backend_response_headers, backend_query_result) - return __GetBodiesFromQueryResult(backend_query_result) - _checkpoint_and_reraise(post_error) - except Exception as post_error: # pylint: disable=broad-exception-caught - _checkpoint_and_reraise(post_error) - feedrange_response_headers = backend_response_headers - self.last_response_headers = feedrange_response_headers - if internal_headers_capture is not None: - _capture_internal_headers(backend_response_headers) - self._UpdateSessionIfRequired(req_headers, backend_query_result, backend_response_headers) - - # Merge results, falling back to a plain extend if the - # aggregating merge raises (it can on aggregated queries - # during splits). - try: - results = base._merge_query_results(results, backend_query_result, query) - except ValueError as merge_error: - base._raise_query_merge_value_error(merge_error) - except (TypeError, KeyError) as merge_error: - _LOGGER.warning( - "Falling back to non-aggregate merge after aggregate merge failure: %s", - merge_error, - ) - results_docs = results.get("Documents") if results else None - partial_docs = backend_query_result.get("Documents") if backend_query_result else None - if isinstance(results_docs, list) and isinstance(partial_docs, list): - results_docs.extend(partial_docs) - elif backend_query_result: - results = backend_query_result + results_docs = results.get("Documents") if results else None + partial_docs = backend_query_result.get("Documents") if backend_query_result else None + if isinstance(results_docs, list) and isinstance(partial_docs, list): + results_docs.extend(partial_docs) + elif backend_query_result: + results = backend_query_result + except Exception as mid_page_error: # pylint: disable=broad-exception-caught + await _checkpoint_and_reraise(mid_page_error) previous_feedrange = pagination_state.head_range previous_backend_continuation = pagination_state.head_bc @@ -3406,6 +3416,9 @@ def _checkpoint_and_reraise(error: Exception) -> None: # Pagination loop is done — write the final outbound # continuation (or clear the header if the queue is fully # drained) so the caller's ``by_page`` loop terminates. + single_partition_scope_for_outbound = ( + (not is_full_pk_structured_scope) and (await _is_input_scope_single_partition()) + ) _write_query_outbound_continuation( feedrange_response_headers, pagination_state, @@ -3413,6 +3426,7 @@ def _checkpoint_and_reraise(error: Exception) -> None: query, feed_range_epk, is_full_pk_structured_scope, + single_partition_scope_for_outbound, query_hash, feedrange_hash, ) diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py index fa965ee34931..8be26c5f5e09 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -54,7 +54,9 @@ _hash_query_spec, _stable_hash_128, _normalize_max_item_count, + _should_bridge_legacy_continuation, _increment_explode_iterations_or_raise, + _write_query_outbound_continuation, _update_no_progress_page_count, _validate_token_identity, _FIELD_BACKEND_CONTINUATION, @@ -1525,3 +1527,106 @@ def test_drained_state_clears_continuation_header(self): assert http_constants.HttpHeaders.Continuation not in headers + +class TestWriteQueryOutboundContinuation: + """Outbound continuation format selection should match request scope policy.""" + + def test_full_pk_scope_always_emits_legacy(self): + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + _HEAD_FEEDRANGE, + _BACKEND_CONT, + page_size_hint=5, + ) + headers: dict = {} + + _write_query_outbound_continuation( + headers, + state, + _RID, + _QUERY, + _FEED_RANGE, + is_full_pk_structured_scope=True, + emit_legacy_for_single_partition=False, + query_hash=_hash_query_spec(_QUERY), + feedrange_hash=_hash_feed_range(_FEED_RANGE), + ) + + assert headers[http_constants.HttpHeaders.Continuation] == _BACKEND_CONT + + def test_non_full_pk_single_partition_scope_emits_legacy(self): + state = _FeedRangePaginationState.from_single_feedrange_with_continuation( + _HEAD_FEEDRANGE, + _BACKEND_CONT, + page_size_hint=5, + ) + headers: dict = {} + + _write_query_outbound_continuation( + headers, + state, + _RID, + _QUERY, + _FEED_RANGE, + is_full_pk_structured_scope=False, + emit_legacy_for_single_partition=True, + query_hash=_hash_query_spec(_QUERY), + feedrange_hash=_hash_feed_range(_FEED_RANGE), + ) + + assert headers[http_constants.HttpHeaders.Continuation] == _BACKEND_CONT + + def test_non_full_pk_multi_partition_scope_emits_structured(self): + state = _FeedRangePaginationState( + [(_HEAD_FEEDRANGE, _BACKEND_CONT), (_REMAINING_FEEDRANGE, None)], + page_size_hint=5, + ) + headers: dict = {} + + _write_query_outbound_continuation( + headers, + state, + _RID, + _QUERY, + _FEED_RANGE, + is_full_pk_structured_scope=False, + emit_legacy_for_single_partition=False, + query_hash=_hash_query_spec(_QUERY), + feedrange_hash=_hash_feed_range(_FEED_RANGE), + ) + + decoded = _decode_token(headers[http_constants.HttpHeaders.Continuation]) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + + +class TestLegacyBridgeDecision: + """Legacy inbound continuation is bridged only when scope is safely single-partition.""" + + @pytest.mark.parametrize( + "inbound_serialized_continuation,inbound_token_payload,is_full_pk_structured_scope," + "is_single_partition_scope,expected", + [ + (None, None, False, True, False), + ("", None, False, True, False), + ("legacy", {"v": 1}, False, True, False), + ("legacy", None, True, False, True), + ("legacy", None, False, True, True), + ("legacy", None, False, False, False), + ], + ) + def test_should_bridge_legacy_continuation_policy( + self, + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + is_single_partition_scope, + expected, + ): + assert _should_bridge_legacy_continuation( + inbound_serialized_continuation, + inbound_token_payload, + is_full_pk_structured_scope, + is_single_partition_scope, + ) is expected + + diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index 935fce63c184..ff2a73103bef 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -15,6 +15,7 @@ from azure.cosmos import exceptions from azure.cosmos._cosmos_client_connection import CosmosClientConnection from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext +from azure.cosmos._routing.feed_range_continuation import _FIELD_VERSION, _TOKEN_VERSION, _decode_token from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders @@ -730,6 +731,147 @@ def test_querfeed_populates_capture_dict_from_options(self): assert result == [{"id": "1"}] assert headers is canned_headers + def test_queryfeed_feed_range_legacy_inbound_single_partition_honors_and_emits_legacy(self): + """Legacy inbound continuation is honored when feed_range currently maps to one partition.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + + def overlap_side_effect(_rid, ranges, _opts): + _ = ranges + return single_overlap + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + seen_request_continuations = [] + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "legacy-next-token"} + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert seen_request_continuations == ["legacy-inbound-token"] + assert headers.get(HttpHeaders.Continuation) == "legacy-next-token" + + def test_queryfeed_feed_range_legacy_inbound_multi_partition_restarts_and_emits_v1(self): + """Legacy inbound continuation is ignored when scope is multi-partition; outbound becomes v=1.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + child_left = {"id": "0", "minInclusive": "00", "maxExclusive": "7F"} + child_right = {"id": "1", "minInclusive": "7F", "maxExclusive": "FF"} + + def overlap_side_effect(_rid, ranges, _opts): + requested = ranges[0] + if requested.min == "00" and requested.max == "FF": + return [child_left, child_right] + if requested.min == "00" and requested.max == "7F": + return [child_left] + if requested.min == "7F" and requested.max == "FF": + return [child_right] + return [] + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + seen_request_continuations = [] + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "child-legacy-token"} + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert seen_request_continuations == [None] + outbound = headers.get(HttpHeaders.Continuation) + decoded = _decode_token(outbound) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + + def test_queryfeed_feed_range_routing_lookup_failure_stamps_checkpoint(self): + """A failure inside the mid-page routing-map lookup must stamp a resumable + checkpoint into ``last_response_headers[Continuation]`` before re-raising, + not just failures from the backend POST. + """ + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + routing_call_count = {"n": 0} + + def overlap_side_effect(_rid, _ranges, _opts): + routing_call_count["n"] += 1 + # First call (legacy bridge classification) succeeds; the mid-page + # iteration call fails so we exercise the widened try block. + if routing_call_count["n"] >= 2: + raise RuntimeError("routing-map-down") + return single_overlap + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post") as post_mock: + with pytest.raises(RuntimeError, match="routing-map-down"): + client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + post_mock.assert_not_called() + + # Checkpoint must be present so the caller can resume on retry. + # Single-partition scope => legacy-format checkpoint (the original inbound token). + continuation = client.last_response_headers.get(HttpHeaders.Continuation) + assert continuation == "legacy-inbound-token" + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index 2c3356568965..42d543e88fe1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -17,6 +17,7 @@ from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders from azure.cosmos.aio import CosmosClient # noqa: F401 - needed to resolve circular imports from azure.cosmos._execution_context.aio.base_execution_context import _DefaultQueryExecutionContext +from azure.cosmos._routing.feed_range_continuation import _FIELD_VERSION, _TOKEN_VERSION, _decode_token # tracemalloc is not available in PyPy, so we import conditionally try: @@ -725,6 +726,112 @@ async def test_querfeed_populates_capture_dict_from_options_async(self): # Sanity check: async no-query branch returns just the body list. assert result == [{"id": "1"}] + async def test_queryfeed_feed_range_legacy_inbound_single_partition_honors_and_emits_legacy_async(self): + """Async: legacy inbound continuation is honored when feed_range maps to one partition.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + + async def overlap_side_effect(_rid, ranges, _opts): + _ = ranges + return single_overlap + + client._routing_map_provider.get_overlapping_ranges = AsyncMock(side_effect=overlap_side_effect) + + seen_request_continuations = [] + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "legacy-next-token"} + + async def _noop_set_session(*args, **kwargs): + return None + + with patch("azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}): + with patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + side_effect=_noop_set_session, + ): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert seen_request_continuations == ["legacy-inbound-token"] + assert headers.get(HttpHeaders.Continuation) == "legacy-next-token" + + async def test_queryfeed_feed_range_legacy_inbound_multi_partition_restarts_and_emits_v1_async(self): + """Async: legacy inbound continuation is ignored when scope is multi-partition; outbound becomes v=1.""" + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + child_left = {"id": "0", "minInclusive": "00", "maxExclusive": "7F"} + child_right = {"id": "1", "minInclusive": "7F", "maxExclusive": "FF"} + + async def overlap_side_effect(_rid, ranges, _opts): + requested = ranges[0] + if requested.min == "00" and requested.max == "FF": + return [child_left, child_right] + if requested.min == "00" and requested.max == "7F": + return [child_left] + if requested.min == "7F" and requested.max == "FF": + return [child_right] + return [] + + client._routing_map_provider.get_overlapping_ranges = AsyncMock(side_effect=overlap_side_effect) + + seen_request_continuations = [] + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + return {"Documents": [{"id": "doc-1"}]}, {HttpHeaders.Continuation: "child-legacy-token"} + + async def _noop_set_session(*args, **kwargs): + return None + + with patch("azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}): + with patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + side_effect=_noop_set_session, + ): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + docs, headers = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + + assert docs == [{"id": "doc-1"}] + assert seen_request_continuations == [None] + outbound = headers.get(HttpHeaders.Continuation) + decoded = _decode_token(outbound) + assert decoded is not None + assert decoded[_FIELD_VERSION] == _TOKEN_VERSION + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index deb6c8c7479e..84ac3116efab 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -516,138 +516,6 @@ def test_full_pk_legacy_replay_resumes_same_page(self): replay_second_page = list(replay_pager.next())[0] self.assertEqual(second_page['id'], replay_second_page['id']) - def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error(self): - """A non-HTTP error during the wrapped legacy-bridge POST must - propagate as-is. The 400-only fallback must NOT swallow runtime - exceptions or retry the page on them. - """ - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - legacy_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_pager = legacy_iterable.by_page() - legacy_pager.next() - legacy_token = legacy_pager.continuation_token - # By the policy under review, full-PK always emits legacy. - self.assertIsNone(_decode_token(legacy_token)) - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - - def _failing_post(*args, **kwargs): - raise RuntimeError("bridge-runtime-error") - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with self.assertRaisesRegex(RuntimeError, 'bridge-runtime-error'): - legacy_iterable.by_page(legacy_token).next() - finally: - client_conn._CosmosClientConnection__Post = original_post - - def test_full_pk_legacy_bridge_does_not_fallback_on_unrelated_service_error(self): - """Only 400 BadRequest triggers the wrapped-path fallback. Other - CosmosHttpResponseError statuses (e.g. 429) must propagate - without an old-shape retry being issued. - """ - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - legacy_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_pager = legacy_iterable.by_page() - legacy_pager.next() - legacy_token = legacy_pager.continuation_token - self.assertIsNone(_decode_token(legacy_token)) - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - saw_legacy_fallback_headers = False - - def _failing_post(*args, **kwargs): - nonlocal saw_legacy_fallback_headers - req_headers = args[3] - # The fallback would re-issue with raw PartitionKey + raw - # legacy continuation + no PartitionKeyRangeID. If we ever - # see that shape on a 429-only path, the gating is wrong. - if ( - http_constants.HttpHeaders.PartitionKey in req_headers - and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token - and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers - ): - saw_legacy_fallback_headers = True - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, - message="throttled", - ) - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with self.assertRaises(exceptions.CosmosHttpResponseError): - legacy_iterable.by_page(legacy_token).next() - self.assertFalse(saw_legacy_fallback_headers) - finally: - client_conn._CosmosClientConnection__Post = original_post - - def test_full_pk_legacy_bridge_fallback_failure_stamps_checkpoint(self): - """When the wrapped POST returns 400 and the one-shot legacy - fallback POST also fails, the original error is surfaced AND a - resumable checkpoint continuation must be stamped on - ``last_response_headers`` so the caller can retry. - """ - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - legacy_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_pager = legacy_iterable.by_page() - legacy_pager.next() - legacy_token = legacy_pager.continuation_token - self.assertIsNone(_decode_token(legacy_token)) - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - post_call_count = 0 - - def _failing_post(*args, **kwargs): - nonlocal post_call_count - post_call_count += 1 - if post_call_count == 1: - # First wrapped POST: 400 -> triggers one-shot fallback. - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.BAD_REQUEST, - message="legacy bridge compatibility failure", - ) - # Fallback POST: also fails, original error should surface. - raise RuntimeError("fallback-post-failed") - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with self.assertRaisesRegex(RuntimeError, 'fallback-post-failed'): - legacy_iterable.by_page(legacy_token).next() - self.assertEqual(post_call_count, 2) - # Checkpoint stamped: full-PK emits legacy, so the - # checkpoint continuation is the original legacy token. - continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) - self.assertIsNotNone(continuation) - self.assertIsNone(_decode_token(continuation)) - finally: - client_conn._CosmosClientConnection__Post = original_post def test_cross_partition_query_with_none_partition_key(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index 0fa3552442d0..4ec4af5c1406 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -515,131 +515,6 @@ async def test_full_pk_legacy_replay_resumes_same_page_async(self): assert second_page['id'] == replay_second_page['id'] - async def test_full_pk_legacy_bridge_does_not_fallback_on_runtime_error_async(self): - """A non-HTTP error during the wrapped legacy-bridge POST must - propagate as-is on the async path. The 400-only fallback must - NOT swallow runtime exceptions or retry the page on them. - """ - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - legacy_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_pager = legacy_iterable.by_page() - await legacy_pager.__anext__() - legacy_token = legacy_pager.continuation_token - assert _decode_token(legacy_token) is None - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - - async def _failing_post(*args, **kwargs): - raise RuntimeError("bridge-runtime-error-async") - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with pytest.raises(RuntimeError, match='bridge-runtime-error-async'): - await legacy_iterable.by_page(legacy_token).__anext__() - finally: - client_conn._CosmosClientConnection__Post = original_post - - async def test_full_pk_legacy_bridge_does_not_fallback_on_unrelated_service_error_async(self): - """Only 400 BadRequest triggers the wrapped-path fallback. Other - statuses (e.g. 429) must propagate without an old-shape retry - being issued, on the async path. - """ - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - legacy_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_pager = legacy_iterable.by_page() - await legacy_pager.__anext__() - legacy_token = legacy_pager.continuation_token - assert _decode_token(legacy_token) is None - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - saw_legacy_fallback_headers = False - - async def _failing_post(*args, **kwargs): - nonlocal saw_legacy_fallback_headers - req_headers = args[3] - if ( - http_constants.HttpHeaders.PartitionKey in req_headers - and req_headers.get(http_constants.HttpHeaders.Continuation) == legacy_token - and http_constants.HttpHeaders.PartitionKeyRangeID not in req_headers - ): - saw_legacy_fallback_headers = True - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.TOO_MANY_REQUESTS, - message="throttled", - ) - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with pytest.raises(exceptions.CosmosHttpResponseError): - await legacy_iterable.by_page(legacy_token).__anext__() - assert not saw_legacy_fallback_headers - finally: - client_conn._CosmosClientConnection__Post = original_post - - async def test_full_pk_legacy_bridge_fallback_failure_stamps_checkpoint_async(self): - """When the wrapped POST returns 400 and the one-shot legacy - fallback POST also fails, the original error is surfaced AND a - resumable checkpoint continuation must be stamped on - ``last_response_headers`` (async path). - """ - created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - await created_collection.upsert_item(body={'pk': 'pk', 'id': str(uuid.uuid4())}) - - legacy_iterable = created_collection.query_items( - query='SELECT * from c', - partition_key='pk', - max_item_count=1, - ) - legacy_pager = legacy_iterable.by_page() - await legacy_pager.__anext__() - legacy_token = legacy_pager.continuation_token - assert _decode_token(legacy_token) is None - - client_conn = created_collection.client_connection - original_post = client_conn._CosmosClientConnection__Post - post_call_count = 0 - - async def _failing_post(*args, **kwargs): - nonlocal post_call_count - post_call_count += 1 - if post_call_count == 1: - raise exceptions.CosmosHttpResponseError( - status_code=http_constants.StatusCodes.BAD_REQUEST, - message="legacy bridge compatibility failure", - ) - raise RuntimeError("fallback-post-failed-async") - - client_conn._CosmosClientConnection__Post = _failing_post - try: - with pytest.raises(RuntimeError, match='fallback-post-failed-async'): - await legacy_iterable.by_page(legacy_token).__anext__() - assert post_call_count == 2 - continuation = client_conn.last_response_headers.get(http_constants.HttpHeaders.Continuation) - assert continuation is not None - assert _decode_token(continuation) is None - finally: - client_conn._CosmosClientConnection__Post = original_post - async def test_cross_partition_query_with_none_partition_key_async(self): created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) From 175e3b8a489620f223e6d6c037d10dd55d28a11c Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 11 May 2026 16:18:57 -0500 Subject: [PATCH 09/18] fixing pipeline errors --- .../azure/cosmos/_cosmos_client_connection.py | 17 ++---- .../aio/base_execution_context.py | 1 - .../_routing/feed_range_continuation.py | 50 ++++++----------- .../aio/_cosmos_client_connection_async.py | 17 ++---- .../tests/test_computed_properties.py | 38 +++++++++---- .../tests/test_computed_properties_async.py | 26 ++++++--- .../test_feed_range_continuation_token.py | 50 ----------------- .../tests/test_full_text_policy.py | 42 +++++++++++---- .../tests/test_full_text_policy_async.py | 53 +++++++++++++------ .../tests/test_partition_split_retry_unit.py | 6 +-- .../test_partition_split_retry_unit_async.py | 6 +-- 11 files changed, 147 insertions(+), 159 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 790194e72ba7..65328b16f4b9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3350,12 +3350,6 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # ``None`` means start from the beginning of the requested # feed range. page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) - # Identity hashes are only needed when an inbound v=1 token must be - # validated, or when the outbound continuation will be a v=1 envelope. - # Compute lazily so the single-partition legacy outbound path pays no - # hashing cost. The token writers handle ``None`` by computing on demand. - query_hash: Optional[str] = None - feedrange_hash: Optional[str] = None # Single shared copy of options for routing-map lookups in this call. # ``get_overlapping_ranges`` does not mutate options; copying once # avoids per-iteration ``dict(options)`` allocations. @@ -3366,13 +3360,16 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # Cache for the input scope's single-partition classification. # We compute it at most once per __QueryFeed call so inbound # bridge-detection, mid-page checkpoint, and end-of-page outbound - # writer all agree even if the PKRANGE cache refreshes mid-call. + # writer all agree even if the PK range cache refreshes mid-call. cached_is_single_partition: Optional[bool] = None def _is_input_scope_single_partition() -> bool: """Return True when the caller input range currently maps to one physical partition. Result is cached for the duration of this __QueryFeed call. + + :returns: True if the input scope maps to a single physical partition. + :rtype: bool """ nonlocal cached_is_single_partition if cached_is_single_partition is None: @@ -3416,8 +3413,6 @@ def _is_input_scope_single_partition() -> bool: resource_id_str, query, feed_range_epk, - expected_query_hash=query_hash, - expected_feedrange_hash=feedrange_hash, ) pagination_state = _FeedRangePaginationState.from_inbound( inbound_token_payload, page_size_hint @@ -3469,8 +3464,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: feed_range_epk, is_full_pk_structured_scope, (not is_full_pk_structured_scope) and _is_input_scope_single_partition(), - query_hash, - feedrange_hash, ) except Exception as continuation_write_error: # pylint: disable=broad-exception-caught _LOGGER.warning( @@ -3624,8 +3617,6 @@ def _checkpoint_and_reraise(error: Exception) -> None: feed_range_epk, is_full_pk_structured_scope, (not is_full_pk_structured_scope) and _is_input_scope_single_partition(), - query_hash, - feedrange_hash, ) # End feed_range pagination block. self.last_response_headers = feedrange_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py index fd356c94d77c..e57dffbccc06 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py @@ -191,7 +191,6 @@ async def callback(**kwargs): # pylint: disable=unused-argument # Refresh routing map to get new partition key ranges self._client.refresh_routing_map_provider() - # Reset execution context state to allow retry from the beginning # Reset execution context state for retry. If __QueryFeed already # stamped a checkpoint continuation on failure, resume from it. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py index a778d2cdd58c..e2e4462d8a04 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/feed_range_continuation.py @@ -265,8 +265,6 @@ def _validate_token_identity( resource_id: str, query: Any, feed_range_epk: routing_range.Range, - expected_query_hash: Optional[str] = None, - expected_feedrange_hash: Optional[str] = None, ) -> None: """Confirm the inbound token was created for the same collection, query, and feed_range the current call is using. If any of the @@ -282,13 +280,9 @@ def _validate_token_identity( :type query: str or dict :param feed_range_epk: Current feed range scope. :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range - :param expected_query_hash: Precomputed query hash to validate against inbound token. - :type expected_query_hash: Optional[str] - :param expected_feedrange_hash: Precomputed feed_range hash to validate against inbound token. - :type expected_feedrange_hash: Optional[str] """ - expected_qh = expected_query_hash or _hash_query_spec(query) - expected_frh = expected_feedrange_hash or _hash_feed_range(feed_range_epk) + expected_qh = _hash_query_spec(query) + expected_frh = _hash_feed_range(feed_range_epk) if inbound[_FIELD_COLLECTION_RID] != resource_id: raise ValueError( "Continuation token was created for a different collection " @@ -320,6 +314,18 @@ def _should_bridge_legacy_continuation( * full-PK structured scope (always single partition), or * non-full-PK scope that currently maps to one physical partition. + + :param inbound_serialized_continuation: The raw inbound continuation token, if any. + :type inbound_serialized_continuation: str or None + :param inbound_token_payload: The decoded structured (``v=1``) token payload, or + ``None`` if the inbound token is legacy/opaque. + :type inbound_token_payload: dict or None + :param bool is_full_pk_structured_scope: True if the request scope is a full + partition key (always single partition). + :param bool is_single_partition_scope: True if the request scope currently maps + to a single physical partition. + :returns: True if the legacy continuation should be bridged into pagination state. + :rtype: bool """ return bool( inbound_serialized_continuation @@ -590,8 +596,6 @@ def write_outbound_continuation( resource_id: str, query: Any, feed_range_epk: routing_range.Range, - query_hash: Optional[str] = None, - feedrange_hash: Optional[str] = None, ) -> None: """Set or clear the outbound continuation header from the queue. @@ -608,10 +612,6 @@ def write_outbound_continuation( :type query: str or dict :param feed_range_epk: Original request feed range. :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range - :param query_hash: Optional precomputed query hash to embed in the outbound token. - :type query_hash: Optional[str] - :param feedrange_hash: Optional precomputed feed_range hash to embed in the outbound token. - :type feedrange_hash: Optional[str] """ if not self.queue: last_response_headers.pop(http_constants.HttpHeaders.Continuation, None) @@ -621,8 +621,6 @@ def write_outbound_continuation( query, feed_range_epk, self.queue, - query_hash=query_hash, - feedrange_hash=feedrange_hash, ) @@ -634,8 +632,6 @@ def _write_query_outbound_continuation( feed_range_epk: routing_range.Range, is_full_pk_structured_scope: bool, emit_legacy_for_single_partition: bool, - query_hash: Optional[str], - feedrange_hash: Optional[str], ) -> None: """Write outbound continuation for feed-range pagination. @@ -660,12 +656,6 @@ def _write_query_outbound_continuation( :param emit_legacy_for_single_partition: Whether non-full-PK scope currently maps to a single physical partition and can safely emit legacy continuation. :type emit_legacy_for_single_partition: bool - :param query_hash: Precomputed query hash for outbound token identity, or ``None`` to - compute lazily inside the structured-token writer when needed. - :type query_hash: Optional[str] - :param feedrange_hash: Precomputed feed range hash for outbound token identity, or - ``None`` to compute lazily inside the structured-token writer when needed. - :type feedrange_hash: Optional[str] :returns: None. Mutates ``last_response_headers`` in place. :rtype: None """ @@ -681,8 +671,6 @@ def _write_query_outbound_continuation( resource_id, query, feed_range_epk, - query_hash=query_hash, - feedrange_hash=feedrange_hash, ) @@ -692,8 +680,6 @@ def _build_outbound_token( query: Any, feed_range_epk: routing_range.Range, entries: Iterable[Tuple[routing_range.Range, Optional[str]]], - query_hash: Optional[str] = None, - feedrange_hash: Optional[str] = None, ) -> str: """Build and base64-encode the outbound continuation token from a queue of ``(range, backend_continuation)`` entries. @@ -708,18 +694,14 @@ def _build_outbound_token( :type feed_range_epk: ~azure.cosmos._routing.routing_range.Range :param entries: Ordered ``(range, bc)`` pairs to serialize. :type entries: Iterable[tuple[~azure.cosmos._routing.routing_range.Range, Optional[str]]] - :param query_hash: Optional precomputed query hash to persist in the token envelope. - :type query_hash: Optional[str] - :param feedrange_hash: Optional precomputed feed_range hash to persist in the token envelope. - :type feedrange_hash: Optional[str] :returns: Encoded continuation token. :rtype: str """ payload = { _FIELD_VERSION: _TOKEN_VERSION, _FIELD_COLLECTION_RID: resource_id, - _FIELD_QUERY_HASH: query_hash or _hash_query_spec(query), - _FIELD_FEEDRANGE_HASH: feedrange_hash or _hash_feed_range(feed_range_epk), + _FIELD_QUERY_HASH: _hash_query_spec(query), + _FIELD_FEEDRANGE_HASH: _hash_feed_range(feed_range_epk), _FIELD_CONTINUATIONS: [ {"min": r.min, "max": r.max, _FIELD_BACKEND_CONTINUATION: bc} for r, bc in entries diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 7f12ddbc3d7a..0213054a5d05 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -3147,12 +3147,6 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # ``None`` means start from the beginning of the requested # feed range. page_size_hint = _normalize_max_item_count(options.get("maxItemCount")) - # Identity hashes are only needed when an inbound v=1 token must be - # validated, or when the outbound continuation will be a v=1 envelope. - # Compute lazily so the single-partition legacy outbound path pays no - # hashing cost. The token writers handle ``None`` by computing on demand. - query_hash: Optional[str] = None - feedrange_hash: Optional[str] = None # Single shared copy of options for routing-map lookups in this call. # ``get_overlapping_ranges`` does not mutate options; copying once # avoids per-iteration ``dict(options)`` allocations. @@ -3163,13 +3157,16 @@ def __GetBodiesFromQueryResult(result: dict[str, Any]) -> list[dict[str, Any]]: # Cache for the input scope's single-partition classification. # We compute it at most once per __QueryFeed call so inbound # bridge-detection, mid-page checkpoint, and end-of-page outbound - # writer all agree even if the PKRANGE cache refreshes mid-call. + # writer all agree even if the PK range cache refreshes mid-call. cached_is_single_partition: Optional[bool] = None async def _is_input_scope_single_partition() -> bool: """Return True when the caller input range currently maps to one physical partition. Result is cached for the duration of this __QueryFeed call. + + :returns: True if the input scope maps to a single physical partition. + :rtype: bool """ nonlocal cached_is_single_partition if cached_is_single_partition is None: @@ -3213,8 +3210,6 @@ async def _is_input_scope_single_partition() -> bool: resource_id_str, query, feed_range_epk, - expected_query_hash=query_hash, - expected_feedrange_hash=feedrange_hash, ) pagination_state = _FeedRangePaginationState.from_inbound( inbound_token_payload, page_size_hint @@ -3265,8 +3260,6 @@ async def _checkpoint_and_reraise(error: Exception) -> None: feed_range_epk, is_full_pk_structured_scope, single_partition_scope_for_outbound, - query_hash, - feedrange_hash, ) except Exception as continuation_write_error: # pylint: disable=broad-exception-caught _LOGGER.warning( @@ -3427,8 +3420,6 @@ async def _checkpoint_and_reraise(error: Exception) -> None: feed_range_epk, is_full_pk_structured_scope, single_partition_scope_for_outbound, - query_hash, - feedrange_hash, ) # End feed_range pagination block. self.last_response_headers = feedrange_response_headers diff --git a/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py b/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py index c41fcd52567f..e6bc283de0e4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py +++ b/sdk/cosmos/azure-cosmos/tests/test_computed_properties.py @@ -167,15 +167,29 @@ def test_replace_with_new_computed_properties(self): # Check that computed properties were properly sent to replaced container self.assertListEqual(new_computed_properties, replaced_collection.read()["computedProperties"]) - # Test 1: Test first computed property - queried_items = list( - replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', - partition_key="test")) + # Test 1: Test first computed property. Allow brief settle time after replace + # for the new computed-property index to be ready. + queried_items = [] + for _ in range(10): + queried_items = list( + replaced_collection.query_items(query='Select * from c Where c.cp_upper = "GROUP2"', + partition_key="test")) + if len(queried_items) == 3: + break + time.sleep(1) self.assertEqual(len(queried_items), 3) - # Test 1 Negative: Test if using non-existent computed property name returns nothing - queried_items = list( - replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', partition_key="test")) + # Test 1 Negative: Test if using non-existent computed property name returns nothing. + # After a container replace the backend may take a moment to drop the old + # computed property index ("cp_lower" from before the replace), so retry briefly. + queried_items = [] + for _ in range(10): + queried_items = list( + replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', + partition_key="test")) + if len(queried_items) == 0: + break + time.sleep(1) self.assertEqual(len(queried_items), 0) # Test 2: Test Second Computed Property @@ -184,8 +198,14 @@ def test_replace_with_new_computed_properties(self): self.assertEqual(len(queried_items), 2) # Test 2 Negative: Test Str length using old computed properties name - queried_items = list( - replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', partition_key="test")) + queried_items = [] + for _ in range(10): + queried_items = list( + replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', + partition_key="test")) + if len(queried_items) == 0: + break + time.sleep(1) self.assertEqual(len(queried_items), 0) self.created_db.delete_container(created_collection.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py b/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py index 7eaafb7a9e37..1f48b87558a0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_computed_properties_async.py @@ -195,10 +195,17 @@ async def test_replace_with_new_computed_properties_async(self): # Test 1: Test first computed property self.assertEqual(len(queried_items), 3) - # Test 1 Negative: Test if using non-existent computed property name returns nothing - queried_items = [q async for q in - replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', - partition_key="test")] + # Test 1 Negative: Test if using non-existent computed property name returns nothing. + # After a container replace the backend may take a moment to drop the old + # computed property index ("cp_lower" from before the replace), so retry briefly. + queried_items = [] + for _ in range(10): + queried_items = [q async for q in + replaced_collection.query_items(query='Select * from c Where c.cp_lower = "group1"', + partition_key="test")] + if len(queried_items) == 0: + break + await asyncio.sleep(1) self.assertEqual(len(queried_items), 0) # Test 2: Test Second Computed Property @@ -208,9 +215,14 @@ async def test_replace_with_new_computed_properties_async(self): self.assertEqual(len(queried_items), 2) # Test 2 Negative: Test Str length using old computed properties name - queried_items = [q async for q in - replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', - partition_key="test")] + queried_items = [] + for _ in range(10): + queried_items = [q async for q in + replaced_collection.query_items(query='Select * from c Where c.cp_str_len = 9', + partition_key="test")] + if len(queried_items) == 0: + break + await asyncio.sleep(1) self.assertEqual(len(queried_items), 0) await self.created_db.delete_container(created_collection.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py index 8be26c5f5e09..303a4f125556 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -228,28 +228,6 @@ def test_build_outbound_token_emits_valid_token(self): assert "cf" not in decoded assert "rf" not in decoded - def test_build_outbound_token_uses_precomputed_hashes_without_rehash(self, monkeypatch): - monkeypatch.setattr( - "azure.cosmos._routing.feed_range_continuation._hash_query_spec", - lambda _query: (_ for _ in ()).throw(AssertionError("query hash should not be recomputed")), - ) - monkeypatch.setattr( - "azure.cosmos._routing.feed_range_continuation._hash_feed_range", - lambda _feed_range: (_ for _ in ()).throw(AssertionError("feed_range hash should not be recomputed")), - ) - - wire = _build_outbound_token( - resource_id=_RID, - query=_QUERY, - feed_range_epk=_FEED_RANGE, - entries=[(_HEAD_FEEDRANGE, _BACKEND_CONT)], - query_hash="precomputed-query-hash", - feedrange_hash="precomputed-feedrange-hash", - ) - decoded = _decode_token(wire) - assert decoded is not None - assert decoded[_FIELD_QUERY_HASH] == "precomputed-query-hash" - assert decoded[_FIELD_FEEDRANGE_HASH] == "precomputed-feedrange-hash" def test_per_entry_backend_continuations_coexist(self): # The shape that motivated the flat ``c`` list: a future @@ -553,28 +531,6 @@ def test_call_site_replay_with_different_feed_range_raises(self): ) assert "feed_range" in str(excinfo.value).lower() - def test_validate_token_identity_uses_precomputed_hashes_without_rehash(self, monkeypatch): - payload = _make_valid_token_payload() - inbound = _decode_token(_encode_token(payload)) - assert inbound is not None - - monkeypatch.setattr( - "azure.cosmos._routing.feed_range_continuation._hash_query_spec", - lambda _query: (_ for _ in ()).throw(AssertionError("query hash should not be recomputed")), - ) - monkeypatch.setattr( - "azure.cosmos._routing.feed_range_continuation._hash_feed_range", - lambda _feed_range: (_ for _ in ()).throw(AssertionError("feed_range hash should not be recomputed")), - ) - - _validate_token_identity( - inbound, - resource_id=_RID, - query=_QUERY, - feed_range_epk=_FEED_RANGE, - expected_query_hash=inbound[_FIELD_QUERY_HASH], - expected_feedrange_hash=inbound[_FIELD_FEEDRANGE_HASH], - ) # ---------------------------------------------------------------------- # @@ -1547,8 +1503,6 @@ def test_full_pk_scope_always_emits_legacy(self): _FEED_RANGE, is_full_pk_structured_scope=True, emit_legacy_for_single_partition=False, - query_hash=_hash_query_spec(_QUERY), - feedrange_hash=_hash_feed_range(_FEED_RANGE), ) assert headers[http_constants.HttpHeaders.Continuation] == _BACKEND_CONT @@ -1569,8 +1523,6 @@ def test_non_full_pk_single_partition_scope_emits_legacy(self): _FEED_RANGE, is_full_pk_structured_scope=False, emit_legacy_for_single_partition=True, - query_hash=_hash_query_spec(_QUERY), - feedrange_hash=_hash_feed_range(_FEED_RANGE), ) assert headers[http_constants.HttpHeaders.Continuation] == _BACKEND_CONT @@ -1590,8 +1542,6 @@ def test_non_full_pk_multi_partition_scope_emits_structured(self): _FEED_RANGE, is_full_pk_structured_scope=False, emit_legacy_for_single_partition=False, - query_hash=_hash_query_spec(_QUERY), - feedrange_hash=_hash_feed_range(_FEED_RANGE), ) decoded = _decode_token(headers[http_constants.HttpHeaders.Continuation]) diff --git a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py index 30e9792b5783..4f1703293a10 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py +++ b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py @@ -11,6 +11,26 @@ from azure.cosmos import CosmosClient, PartitionKey +def _assert_full_text_policy_matches(returned, expected): + """Service may augment the returned policy with extra fields (e.g. ``defaultSpec``). + Only assert the fields the caller actually specified are present and equal. + """ + for key, value in expected.items(): + assert returned.get(key) == value, ( + f"fullTextPolicy field {key!r} mismatch: expected {value!r}, got {returned.get(key)!r}" + ) + + +def _assert_message_contains(message, *needles): + """Case-insensitive substring check tolerant of minor service message wording + changes (e.g. ``"Full Text Policy"`` -> ``"full-text policy"``, + ``abstract`` -> ``'abstract'``).""" + haystack = message.lower().replace("'", "").replace("-", " ") + for needle in needles: + n = needle.lower().replace("'", "").replace("-", " ") + assert n in haystack, f"expected {needle!r} in error message, got: {message!r}" + + @pytest.mark.cosmosSearchQuery class TestFullTextPolicy(unittest.TestCase): client: CosmosClient = None @@ -54,8 +74,8 @@ def test_create_full_text_container(self): indexing_policy=indexing_policy ) properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) self.test_db.delete_container(created_container.id) # Create a container with a full text policy containing only default language @@ -68,7 +88,7 @@ def test_create_full_text_container(self): full_text_policy=full_text_policy_no_paths, ) properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_paths + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_paths) self.test_db.delete_container(created_container.id) # Create a container with a full text policy with a given path containing only default language @@ -86,7 +106,7 @@ def test_create_full_text_container(self): full_text_policy=full_text_policy_no_langs, ) properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_langs + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_langs) self.test_db.delete_container(created_container.id) def test_replace_full_text_container(self): @@ -120,7 +140,7 @@ def test_replace_full_text_container(self): indexing_policy=indexing_policy ) properties = replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['indexingPolicy'] != properties['indexingPolicy'] self.test_db.delete_container(created_container.id) @@ -133,7 +153,7 @@ def test_replace_full_text_container(self): indexing_policy=indexing_policy ) created_container_properties = created_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] # Replace the container with new policies @@ -146,7 +166,7 @@ def test_replace_full_text_container(self): indexing_policy=indexing_policy ) properties = replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['fullTextPolicy'] != properties['fullTextPolicy'] assert created_container_properties["indexingPolicy"] != properties["indexingPolicy"] @@ -193,8 +213,7 @@ def test_fail_create_full_text_policy(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. Supported languages are:"\ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") # Pass a full text policy with an unsupported path language full_text_policy_wrong_default = { @@ -215,8 +234,7 @@ def test_fail_create_full_text_policy(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. Supported languages are:"\ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") def test_fail_create_full_text_indexing_policy(self): full_text_policy = { @@ -282,7 +300,9 @@ def test_fail_create_full_text_indexing_policy(self): pytest.fail("Container creation should have failed for missing path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Missing path in full-text index specification at index (0)" in e.http_error_message + _assert_message_contains( + e.http_error_message, "missing path", "full text index" + ) if __name__ == '__main__': diff --git a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py index f52092b102a9..72a081b181ac 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy_async.py @@ -13,6 +13,26 @@ from azure.cosmos.aio import CosmosClient +def _assert_full_text_policy_matches(returned, expected): + """Service may augment the returned policy with extra fields (e.g. ``defaultSpec``). + Only assert the fields the caller actually specified are present and equal. + """ + for key, value in expected.items(): + assert returned.get(key) == value, ( + f"fullTextPolicy field {key!r} mismatch: expected {value!r}, got {returned.get(key)!r}" + ) + + +def _assert_message_contains(message, *needles): + """Case-insensitive substring check tolerant of minor service message wording + changes (e.g. ``"Full Text Policy"`` -> ``"full-text policy"``, + ``abstract`` -> ``'abstract'``).""" + haystack = message.lower().replace("'", "").replace("-", " ") + for needle in needles: + n = needle.lower().replace("'", "").replace("-", " ") + assert n in haystack, f"expected {needle!r} in error message, got: {message!r}" + + @pytest.mark.cosmosSearchQuery class TestFullTextPolicyAsync(unittest.IsolatedAsyncioTestCase): host = test_config.TestConfig.host @@ -69,7 +89,7 @@ async def test_create_full_text_container_async(self): indexing_policy=indexing_policy ) properties = await created_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] await self.test_db.delete_container(created_container.id) @@ -83,7 +103,7 @@ async def test_create_full_text_container_async(self): full_text_policy=full_text_policy_no_paths, ) properties = await created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_paths + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_paths) await self.test_db.delete_container(created_container.id) # Create a container with a full text policy with a given path containing only default language @@ -101,7 +121,7 @@ async def test_create_full_text_container_async(self): full_text_policy=full_text_policy_no_langs, ) properties = await created_container.read() - assert properties["fullTextPolicy"] == full_text_policy_no_langs + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy_no_langs) async def test_replace_full_text_container_async(self): # Replace a container without a full text policy and full text indexing policy @@ -134,7 +154,7 @@ async def test_replace_full_text_container_async(self): indexing_policy=indexing_policy ) properties = await replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['indexingPolicy'] != properties['indexingPolicy'] await self.test_db.delete_container(created_container.id) @@ -147,7 +167,7 @@ async def test_replace_full_text_container_async(self): indexing_policy=indexing_policy ) created_container_properties = await created_container.read() - assert created_container_properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(created_container_properties["fullTextPolicy"], full_text_policy) assert created_container_properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] # Replace the container with new policies @@ -160,7 +180,7 @@ async def test_replace_full_text_container_async(self): indexing_policy=indexing_policy ) properties = await replaced_container.read() - assert properties["fullTextPolicy"] == full_text_policy + _assert_full_text_policy_matches(properties["fullTextPolicy"], full_text_policy) assert properties["indexingPolicy"]['fullTextIndexes'] == indexing_policy['fullTextIndexes'] assert created_container_properties['fullTextPolicy'] != properties['fullTextPolicy'] assert created_container_properties["indexingPolicy"] != properties["indexingPolicy"] @@ -186,7 +206,7 @@ async def test_fail_create_full_text_policy_async(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an invalid Path: abstract" in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "invalid path", "abstract") # Pass a full text policy with an unsupported default language full_text_policy_wrong_default = { @@ -207,8 +227,7 @@ async def test_fail_create_full_text_policy_async(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. Supported languages are:" \ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") # Pass a full text policy with an unsupported path language full_text_policy_wrong_default = { @@ -229,8 +248,7 @@ async def test_fail_create_full_text_policy_async(self): pytest.fail("Container creation should have failed for wrong supported language.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an unsupported language spa-SPA. Supported languages are:" \ - in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "unsupported language", "spa-SPA") async def test_fail_create_full_text_indexing_policy_async(self): full_text_policy = { @@ -259,8 +277,9 @@ async def test_fail_create_full_text_indexing_policy_async(self): # pytest.fail("Container creation should have failed for lack of embedding policy.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The path of the Full Text Index /path does not match the path specified in the Full Text Policy" \ - in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "/path", "does not match", "full text policy" + ) # Pass a full text indexing policy with a wrongly formatted path indexing_policy_wrong_path = { @@ -278,7 +297,9 @@ async def test_fail_create_full_text_indexing_policy_async(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Full-text index specification at index (0) contains invalid path" in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "invalid path" + ) # Pass a full text indexing policy without a path field indexing_policy_no_path = { @@ -296,7 +317,9 @@ async def test_fail_create_full_text_indexing_policy_async(self): pytest.fail("Container creation should have failed for missing path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Missing path in full-text index specification at index (0)" in e.http_error_message + _assert_message_contains( + e.http_error_message, "missing path", "full text index" + ) if __name__ == '__main__': diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index ff2a73103bef..194a6cbea2ec 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -640,7 +640,7 @@ def mock_fetch_function(options): assert pk_refresh_calls == 0, \ f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" - def test_querfeed_populates_capture_dict_from_options(self): + def test_queryfeed_populates_capture_dict_from_options(self): """`__QueryFeed` must read the capture dict from `options` and populate it from the underlying response headers. @@ -669,7 +669,7 @@ def test_querfeed_populates_capture_dict_from_options(self): "_internal_response_headers_capture": capture_dict, } - canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-querfeed"} + canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-queryfeed"} request_obj_mock = MagicMock( set_excluded_location_from_options=MagicMock(), @@ -715,7 +715,7 @@ def test_querfeed_populates_capture_dict_from_options(self): assert mock_get.called, "expected __Get to be invoked on the no-query path" - assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-querfeed", ( + assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-queryfeed", ( f"capture dict was not populated by __QueryFeed; got {capture_dict!r}. " "This indicates __QueryFeed is not reading " "'_internal_response_headers_capture' from options." diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index 42d543e88fe1..8642d37b87de 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -638,7 +638,7 @@ async def mock_fetch_function(options): assert pk_refresh_calls == 0, \ f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" - async def test_querfeed_populates_capture_dict_from_options_async(self): + async def test_queryfeed_populates_capture_dict_from_options_async(self): """Async `__QueryFeed` must read the capture dict from `options` and populate it from the underlying response headers, with no test-side injection. Catches the `options`-vs-`kwargs` @@ -665,7 +665,7 @@ async def test_querfeed_populates_capture_dict_from_options_async(self): "_internal_response_headers_capture": capture_dict, } - canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-querfeed-async"} + canned_headers = {HttpHeaders.Continuation: "checkpoint-from-real-queryfeed-async"} request_obj_mock = MagicMock( set_excluded_location_from_options=MagicMock(), @@ -711,7 +711,7 @@ async def test_querfeed_populates_capture_dict_from_options_async(self): assert mock_get.await_count >= 1, "expected __Get to be awaited on the no-query path" - assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-querfeed-async", ( + assert capture_dict.get(HttpHeaders.Continuation) == "checkpoint-from-real-queryfeed-async", ( f"capture dict was not populated by async __QueryFeed; got {capture_dict!r}. " "This indicates async __QueryFeed is not reading " "'_internal_response_headers_capture' from options." From 490fa5bc44db268cdf406016cfa3eb01bcd69a01 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 13 May 2026 10:45:53 -0500 Subject: [PATCH 10/18] fixing pylint error --- .../azure-cosmos/azure/cosmos/_cosmos_client_connection.py | 3 +++ .../azure/cosmos/aio/_cosmos_client_connection_async.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 65328b16f4b9..0682f3e05e7b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3184,6 +3184,9 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma :type response_hook: Callable[[Mapping[str, Any], dict[str, Any]], None] :param bool is_query_plan: Specifies if the call is to fetch query plan + :param response_headers_list: + Optional list to which per-request response headers will be appended. + :type response_headers_list: Optional[list[CaseInsensitiveDict]] :returns: A list of the queried resources. :rtype: list :raises SystemError: If the query compatibility mode is undefined. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 0213054a5d05..3fd1541e21ca 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2987,6 +2987,9 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, :type response_hook: Callable[[Mapping[str, Any], dict[str, Any]], None] :param bool is_query_plan: Specifies if the call is to fetch query plan + :param response_headers_list: + Optional list to which per-request response headers will be appended. + :type response_headers_list: Optional[list[CaseInsensitiveDict]] :returns: A list of the queried resources. :rtype: list :raises SystemError: If the query compatibility mode is undefined. From d15829151c27e8839dc40a82f4131ed9d730ba19 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 13 May 2026 14:42:39 -0700 Subject: [PATCH 11/18] Fix test assertions for updated full-text policy service error wording Service-side error messages changed wording (e.g., 'Full Text Policy' -> 'full-text policy', 'abstract' -> "'abstract'"). Replace 3 brittle string-match assertions with the existing _assert_message_contains helper, which normalizes case, single quotes, and hyphens. The async test file already uses this helper consistently; this brings the sync file in line. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../azure-cosmos/tests/test_full_text_policy.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py index 4f1703293a10..272763f07cc6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py +++ b/sdk/cosmos/azure-cosmos/tests/test_full_text_policy.py @@ -192,7 +192,7 @@ def test_fail_create_full_text_policy(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The Full Text Policy contains an invalid Path: abstract" in e.http_error_message + _assert_message_contains(e.http_error_message, "full text policy", "invalid path", "abstract") # Pass a full text policy with an unsupported default language full_text_policy_wrong_default = { @@ -263,8 +263,9 @@ def test_fail_create_full_text_indexing_policy(self): # pytest.fail("Container creation should have failed for lack of embedding policy.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "The path of the Full Text Index /path does not match the path specified in the Full Text Policy"\ - in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "/path", "does not match", "full text policy" + ) # Pass a full text indexing policy with a wrongly formatted path indexing_policy_wrong_path = { @@ -282,7 +283,9 @@ def test_fail_create_full_text_indexing_policy(self): pytest.fail("Container creation should have failed for invalid path.") except exceptions.CosmosHttpResponseError as e: assert e.status_code == 400 - assert "Full-text index specification at index (0) contains invalid path" in e.http_error_message + _assert_message_contains( + e.http_error_message, "full text index", "invalid path" + ) # Pass a full text indexing policy without a path field indexing_policy_no_path = { From ea824cf00c02dbf2e4a7265f0fbe61bab800761f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?McCoy=20Pati=C3=B1o?= <39780829+mccoyp@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:30:38 -0800 Subject: [PATCH 12/18] Remove remaining imports of pkg_resources (#45084) * Remove remaining imports of pkg_resources * Restore comment explaining opentelemetry_version + cleanup --- sdk/evaluation/azure-ai-evaluation/setup.py | 1 - sdk/ml/azure-ai-ml/setup.py | 1 - .../azure/monitor/opentelemetry/exporter/_utils.py | 14 ++------------ 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index 5253c94fa865..33ba4187265c 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -9,7 +9,6 @@ from io import open from typing import Any, Match, cast -import pkg_resources from setuptools import find_packages, setup # Change the PACKAGE_NAME only to change folder and different name diff --git a/sdk/ml/azure-ai-ml/setup.py b/sdk/ml/azure-ai-ml/setup.py index 83c0d47fe922..9e472e9c6dab 100644 --- a/sdk/ml/azure-ai-ml/setup.py +++ b/sdk/ml/azure-ai-ml/setup.py @@ -7,7 +7,6 @@ from io import open from typing import Any, Match, cast -import pkg_resources from setuptools import find_packages, setup # Change the PACKAGE_NAME only to change folder and different name diff --git a/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py b/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py index 48399297eaef..99b16617833e 100644 --- a/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py +++ b/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import datetime +from importlib.metadata import version import locale from os import environ from os.path import isdir @@ -35,19 +36,8 @@ _RP_Names, ) -opentelemetry_version = "" - # Workaround for missing version file -try: - from importlib.metadata import version - - opentelemetry_version = version("opentelemetry-sdk") -except ImportError: - # Temporary workaround for Date: Wed, 13 May 2026 16:52:42 -0700 Subject: [PATCH 13/18] Revert "Remove remaining imports of pkg_resources (#45084)" This reverts commit ea824cf00c02dbf2e4a7265f0fbe61bab800761f. --- sdk/evaluation/azure-ai-evaluation/setup.py | 1 + sdk/ml/azure-ai-ml/setup.py | 1 + .../azure/monitor/opentelemetry/exporter/_utils.py | 14 ++++++++++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index 33ba4187265c..5253c94fa865 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -9,6 +9,7 @@ from io import open from typing import Any, Match, cast +import pkg_resources from setuptools import find_packages, setup # Change the PACKAGE_NAME only to change folder and different name diff --git a/sdk/ml/azure-ai-ml/setup.py b/sdk/ml/azure-ai-ml/setup.py index 9e472e9c6dab..83c0d47fe922 100644 --- a/sdk/ml/azure-ai-ml/setup.py +++ b/sdk/ml/azure-ai-ml/setup.py @@ -7,6 +7,7 @@ from io import open from typing import Any, Match, cast +import pkg_resources from setuptools import find_packages, setup # Change the PACKAGE_NAME only to change folder and different name diff --git a/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py b/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py index 99b16617833e..48399297eaef 100644 --- a/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py +++ b/sdk/monitor/azure-monitor-opentelemetry-exporter/azure/monitor/opentelemetry/exporter/_utils.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import datetime -from importlib.metadata import version import locale from os import environ from os.path import isdir @@ -36,8 +35,19 @@ _RP_Names, ) +opentelemetry_version = "" + # Workaround for missing version file -opentelemetry_version = version("opentelemetry-sdk") +try: + from importlib.metadata import version + + opentelemetry_version = version("opentelemetry-sdk") +except ImportError: + # Temporary workaround for Date: Wed, 13 May 2026 16:53:29 -0700 Subject: [PATCH 14/18] Make discover_targeted_packages tolerate setup.py exec failures Some packages (azure-ai-ml, azure-ai-evaluation) have setup.py files that import pkg_resources at module level. When the build venv has setuptools>=80 (where pkg_resources is gone), exec'ing those setup.py files raises ModuleNotFoundError, which discover_targeted_packages did not catch. This aborted the entire 'Ensure service coverage' regression- matrix step for cosmos hotfix builds. Broaden the exception handler from RuntimeError to Exception so a single bad setup.py omits just that package from discovery instead of aborting the run. The error reason is logged for traceability. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- eng/tools/azure-sdk-tools/ci_tools/functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/eng/tools/azure-sdk-tools/ci_tools/functions.py b/eng/tools/azure-sdk-tools/ci_tools/functions.py index 2ef7c99779f0..4058652af072 100644 --- a/eng/tools/azure-sdk-tools/ci_tools/functions.py +++ b/eng/tools/azure-sdk-tools/ci_tools/functions.py @@ -260,8 +260,12 @@ def discover_targeted_packages( for pkg in collected_packages: try: parsed_packages.append(ParsedSetup.from_path(pkg)) - except RuntimeError as e: - logging.error(f"Unable to parse metadata for package {pkg}, omitting from build.") + except Exception as e: + # Some packages have setup.py files that import modules unavailable in the + # current environment (e.g. pkg_resources removed by setuptools>=80). Such + # packages should be omitted from the build/regression set rather than + # aborting discovery for the entire repo. + logging.error(f"Unable to parse metadata for package {pkg}, omitting from build. Reason: {e}") continue # filter for compatibility, this means excluding a package that doesn't support py36 when we are running a py36 executable From 2e5ae478a84abfb45bcfa76e93fb057b851ede49 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 13 May 2026 17:01:11 -0700 Subject: [PATCH 15/18] Fix combined-sdist lookup for PEP 625 normalized filenames The conda assembly's create_combined_sdist function looks for the sdist file produced by 'python setup.py sdist' inside the assembled folder. With setuptools >= 69 the sdist filename is normalized per PEP 625 (e.g. 'azure_core-1.35.0.tar.gz'), so the substring check '"azure-core" in a' no longer matches and the next() call raises StopIteration, aborting the entire 'Assemble Conda Packages' step. Match the canonicalized form by replacing '-' with '_' before the substring check, identical to the fix already in main (PR #44009). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py b/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py index 83bfed097991..945c24e37eca 100644 --- a/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py +++ b/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py @@ -294,7 +294,7 @@ def create_combined_sdist( [ os.path.join(config_assembled_folder, a) for a in os.listdir(config_assembled_folder) - if os.path.isfile(os.path.join(config_assembled_folder, a)) and conda_build.name in a + if os.path.isfile(os.path.join(config_assembled_folder, a)) and conda_build.name.replace("-", "_") in a ] ) ) From 147d9cd9e2c643b84bbff0fdc014f3dd496c0dc3 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 13 May 2026 20:33:33 -0700 Subject: [PATCH 16/18] Use prefix.dev/conda-forge channel mirror The conda 'env create' step now reaches conda.anaconda.org, which is not in the allowlist for the network-isolated 1ES build agents (requests get sinkholed to 192.0.2.11). Switch the conda env channel to https://prefix.dev/conda-forge, matching the workaround applied upstream on the conda_release_patch branch (commit 1b191674f7). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py b/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py index 945c24e37eca..f581049981f4 100644 --- a/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py +++ b/eng/tools/azure-sdk-tools/ci_tools/conda/conda_functions.py @@ -37,7 +37,7 @@ CONDA_ENV_FILE = """name: azure-build-env channels: - - conda-forge + - https://prefix.dev/conda-forge - defaults dependencies: - python=3.10 From 478624feca512a668080fb897d5475b4c83e3c00 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 13 May 2026 21:20:19 -0700 Subject: [PATCH 17/18] Use prefix.dev/conda-forge for azure-ai-ml conda recipe channel azure-ai-ml is the only conda recipe in the conda-sdk-client pipeline with an explicit additional channel (conda-forge). With network isolation blocking conda.anaconda.org, conda-build fetches against that channel fail with 'HTTP 000 CONNECTION FAILED'. Switch to the prefix.dev mirror so the package's transitive dependencies resolve through the same allowlisted host the build env already uses. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- eng/pipelines/templates/stages/conda-sdk-client.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/pipelines/templates/stages/conda-sdk-client.yml b/eng/pipelines/templates/stages/conda-sdk-client.yml index 5e5509a64fa2..c4d255b5e49b 100644 --- a/eng/pipelines/templates/stages/conda-sdk-client.yml +++ b/eng/pipelines/templates/stages/conda-sdk-client.yml @@ -256,7 +256,7 @@ extends: service: ml in_batch: ${{ parameters.release_azure_ai_ml }} channels: - - conda-forge + - https://prefix.dev/conda-forge checkout: - package: azure-ai-ml version: 1.28.1 From 9923ac6bba47d1fcab42322dc1acf1d3e698db93 Mon Sep 17 00:00:00 2001 From: Copilot <223556219+Copilot@users.noreply.github.com> Date: Wed, 13 May 2026 23:17:06 -0700 Subject: [PATCH 18/18] Pin setuptools<70 in ai-ml/evaluation conda recipes setuptools 81+ removed pkg_resources, which both azure-ai-ml/setup.py and azure-ai-evaluation/setup.py still import unconditionally. The conda build env now ships setuptools 82.0.1 from prefix.dev/conda-forge, so `pip install .` fails with ModuleNotFoundError: No module named 'pkg_resources'. Pin setuptools<70 in host requirements and use --no-build-isolation so pip uses the host env's older setuptools (which still bundles pkg_resources) instead of pulling latest into an isolated build env. This is a conda-recipe-only change so it does not trigger pull request CI for the azure-ai-ml or azure-ai-evaluation packages themselves. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- conda/conda-recipes/azure-ai-evaluation/meta.yaml | 4 +++- conda/conda-recipes/azure-ai-ml/meta.yaml | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/conda/conda-recipes/azure-ai-evaluation/meta.yaml b/conda/conda-recipes/azure-ai-evaluation/meta.yaml index c558b0e6b9e3..0a15747b34b6 100644 --- a/conda/conda-recipes/azure-ai-evaluation/meta.yaml +++ b/conda/conda-recipes/azure-ai-evaluation/meta.yaml @@ -10,10 +10,12 @@ source: build: noarch: python number: 0 - script: "{{ PYTHON }} -m pip install . -vv" + script: "{{ PYTHON }} -m pip install . -vv --no-build-isolation" requirements: host: + - setuptools <70 + - wheel - azure-core >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - azure-identity >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - msrest >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} diff --git a/conda/conda-recipes/azure-ai-ml/meta.yaml b/conda/conda-recipes/azure-ai-ml/meta.yaml index b9cad7784dd5..cf04bb304bc4 100644 --- a/conda/conda-recipes/azure-ai-ml/meta.yaml +++ b/conda/conda-recipes/azure-ai-ml/meta.yaml @@ -10,10 +10,12 @@ source: build: noarch: python number: 0 - script: "{{ PYTHON }} -m pip install . -vv" + script: "{{ PYTHON }} -m pip install . -vv --no-build-isolation" requirements: host: + - setuptools <70 + - wheel - azure-core >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - azure-identity >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }} - msrest >={{ environ.get('AZURESDK_CONDA_VERSION', '0.0.0') }}