|
24 | 24 | from copy import copy |
25 | 25 | from dataclasses import dataclass |
26 | 26 | from enum import Enum |
27 | | -from functools import cached_property, singledispatch |
| 27 | +from functools import cached_property, partial, singledispatch |
28 | 28 | from itertools import chain |
29 | 29 | from typing import ( |
30 | 30 | TYPE_CHECKING, |
|
39 | 39 | Optional, |
40 | 40 | Set, |
41 | 41 | Tuple, |
| 42 | + Type, |
42 | 43 | TypeVar, |
43 | 44 | Union, |
44 | 45 | ) |
45 | 46 |
|
46 | 47 | from pydantic import Field, field_validator |
47 | 48 | from sortedcontainers import SortedList |
| 49 | +from tenacity import ( |
| 50 | + RetryError, |
| 51 | + Retrying, |
| 52 | + retry_if_exception_type, |
| 53 | + stop_after_attempt, |
| 54 | + stop_after_delay, |
| 55 | + wait_exponential, |
| 56 | +) |
48 | 57 | from typing_extensions import Annotated |
49 | 58 |
|
50 | 59 | import pyiceberg.expressions.parser as parser |
51 | 60 | import pyiceberg.expressions.visitors as visitors |
52 | | -from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError |
| 61 | +from pyiceberg.exceptions import ( |
| 62 | + CommitFailedException, |
| 63 | + ResolveError, |
| 64 | + ValidationError, |
| 65 | +) |
53 | 66 | from pyiceberg.expressions import ( |
54 | 67 | AlwaysTrue, |
55 | 68 | And, |
@@ -947,6 +960,97 @@ class CommitTableResponse(IcebergBaseModel): |
947 | 960 | metadata_location: str = Field(alias="metadata-location") |
948 | 961 |
|
949 | 962 |
|
class CommitTableRetryableExceptions:
    """Exception types on which a catalog allows a table commit to be retried.

    Attributes:
        retry_exceptions: Exception types passed as the first argument, stored as given.
        retry_refresh_exceptions: Exception types passed as the second argument, stored as given.
        all: De-duplicated union of both tuples; element order is not guaranteed
            because it is derived from a set.
    """

    def __init__(self, retry_exceptions: tuple[Type[Exception], ...], retry_refresh_exceptions: tuple[Type[Exception], ...]):
        # Build the union first so a type listed in both groups appears only once.
        combined = set(retry_exceptions)
        combined.update(retry_refresh_exceptions)

        self.retry_exceptions: tuple[Type[Exception], ...] = retry_exceptions
        self.retry_refresh_exceptions: tuple[Type[Exception], ...] = retry_refresh_exceptions
        self.all: tuple[Type[Exception], ...] = tuple(combined)
| 970 | + |
| 971 | + |
class TableCommitRetry:
    """Decorator that wraps a table commit method in a tenacity retry loop.

    The decorator is applied to a method at class-definition time, so exactly one
    ``TableCommitRetry`` object exists per decorated method and it is shared by
    every instance of the owning class. Retry settings are therefore read from the
    calling instance on each call and passed around explicitly instead of being
    looked up from state stored on the decorator.

    The calling instance is expected to expose:

    * ``properties`` - a mapping holding the ``commit.retry.*`` settings,
    * ``commit_retry_exceptions`` - a ``CommitTableRetryableExceptions``,
    * ``refresh()`` - called before an attempt when the previous attempt failed
      with one of ``retry_refresh_exceptions`` (or a subclass of one).
    """

    num_retries = "commit.retry.num-retries"
    num_retries_default: int = 4
    min_wait_ms = "commit.retry.min-wait-ms"
    min_wait_ms_default: int = 100
    max_wait_ms = "commit.retry.max-wait-ms"
    max_wait_ms_default: int = 60000  # 1 min
    total_timeout_ms = "commit.retry.total-timeout-ms"
    total_timeout_ms_default: int = 1800000  # 30 mins

    # Names of the attributes looked up on the calling instance.
    properties_attr: str = "properties"
    refresh_attr: str = "refresh"
    commit_retry_exceptions_attr: str = "commit_retry_exceptions"

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func: Callable[..., Any] = func
        # Snapshot of the most recent call's settings. Kept so that get_config() and
        # build_retry_controller() still work when called without explicit arguments;
        # the retry loop itself never reads these.
        self.loaded_properties: Properties = {}
        self.loaded_exceptions: CommitTableRetryableExceptions = CommitTableRetryableExceptions((), ())

    def __get__(self, instance: Any, owner: Any = None) -> Callable[..., Any]:
        """Bind the wrapped function to ``instance``.

        Returns the decorator itself when accessed on the class (``instance is None``)
        rather than a callable bound to ``None``.
        """
        if instance is None:
            return self
        return partial(self.__call__, instance)

    def __call__(self, instance: Table, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped function on ``instance`` under the retry controller.

        Returns:
            Whatever the wrapped function returns on its first successful attempt.

        Raises:
            The exception of the last attempt when it is not retryable or when the
            retry budget (attempts or total time) is exhausted.
        """
        properties = getattr(instance, self.properties_attr)
        exceptions = getattr(instance, self.commit_retry_exceptions_attr)
        self.loaded_properties = properties
        self.loaded_exceptions = exceptions

        refresh_before_attempt = False
        for attempt in self.build_retry_controller(properties, exceptions):
            with attempt:
                if refresh_before_attempt:
                    self.refresh_table(instance)

                result = self.func(instance, *args, **kwargs)

            # The attempt context swallows the exception and records it on the retry
            # state; inspect it to decide whether the next attempt needs a refresh.
            outcome = attempt.retry_state.outcome
            # isinstance (not exact-type membership) so subclasses of a refresh
            # exception are handled the same way the retry predicate handles them.
            refresh_before_attempt = bool(outcome.failed) and isinstance(
                outcome.exception(), exceptions.retry_refresh_exceptions
            )

        return result

    def build_retry_controller(
        self,
        properties: Optional[Properties] = None,
        exceptions: Optional[CommitTableRetryableExceptions] = None,
    ) -> Retrying:
        """Build the retry controller.

        Args:
            properties: Mapping to read the ``commit.retry.*`` settings from.
                Defaults to the snapshot taken by the most recent call.
            exceptions: Retryable exception groups. Defaults to the snapshot taken
                by the most recent call.
        """
        if exceptions is None:
            exceptions = self.loaded_exceptions

        retries = self.get_config(self.num_retries, self.num_retries_default, properties)
        timeout_ms = self.get_config(self.total_timeout_ms, self.total_timeout_ms_default, properties)
        min_wait = self.get_config(self.min_wait_ms, self.min_wait_ms_default, properties)
        max_wait = self.get_config(self.max_wait_ms, self.max_wait_ms_default, properties)

        return Retrying(
            stop=(
                # stop_after_attempt counts attempts, not retries: the initial attempt
                # plus `retries` retries. Without the +1, 0 and 1 behave identically.
                stop_after_attempt(retries + 1)
                | stop_after_delay(datetime.timedelta(milliseconds=timeout_ms))
            ),
            # Settings are in milliseconds; tenacity waits are expressed in seconds.
            wait=wait_exponential(min=min_wait / 1000.0, max=max_wait / 1000.0),
            retry=retry_if_exception_type(exceptions.all),
            # Surface the last attempt's own exception instead of a RetryError wrapper.
            reraise=True,
        )

    def get_config(self, config: str, default: int, properties: Optional[Properties] = None) -> int:
        """Return the integer value of ``config``, or ``default`` when it is absent.

        Args:
            config: Property key to look up.
            default: Value returned when the key is missing or not an integer.
            properties: Mapping to read from. Defaults to the snapshot taken by the
                most recent call.
        """
        source = self.loaded_properties if properties is None else properties
        value = source.get(config)
        if value is None:
            return default
        return self.to_int(value, default, config)

    def refresh_table(self, instance: Table) -> None:
        """Call the refresh method of ``instance``."""
        getattr(instance, self.refresh_attr)()

    @staticmethod
    def to_int(v: str, default: int, config: str) -> int:
        """Convert ``v`` to int; warn and return ``default`` when it cannot be converted."""
        try:
            return int(v)
        except (ValueError, TypeError):
            warnings.warn(f"Expected an integer for table property {config}, got: {v}", category=UserWarning)
            return default
| 1052 | + |
| 1053 | + |
950 | 1054 | class Table: |
951 | 1055 | identifier: Identifier = Field() |
952 | 1056 | metadata: TableMetadata |
@@ -1188,6 +1292,12 @@ def refs(self) -> Dict[str, SnapshotRef]: |
1188 | 1292 | """Return the snapshot references in the table.""" |
1189 | 1293 | return self.metadata.refs |
1190 | 1294 |
|
| 1295 | + @property |
| 1296 | + def commit_retry_exceptions(self) -> CommitTableRetryableExceptions: |
| 1297 | + """Return the commit exceptions that can be retried on the catalog.""" |
| 1298 | + return self.catalog._accepted_commit_retry_exceptions() # pylint: disable=W0212 |
| 1299 | + |
| 1300 | + @TableCommitRetry |
1191 | 1301 | def _do_commit(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...]) -> None: |
1192 | 1302 | response = self.catalog._commit_table( # pylint: disable=W0212 |
1193 | 1303 | CommitTableRequest( |
@@ -1702,7 +1812,8 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema: |
1702 | 1812 | visit_with_partner( |
1703 | 1813 | Catalog._convert_schema_if_needed(new_schema), |
1704 | 1814 | -1, |
1705 | | - UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore |
| 1815 | + UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), |
| 1816 | + # type: ignore |
1706 | 1817 | PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive), |
1707 | 1818 | ) |
1708 | 1819 | return self |
|
0 commit comments