diff --git a/etc/config.properties b/etc/config.properties index 38c58875..bd517b61 100644 --- a/etc/config.properties +++ b/etc/config.properties @@ -16,5 +16,6 @@ protocol.spooling.retrieval-mode=coordinator_proxy # Enable dynamic catalog management catalog.management=dynamic -# Disable http request log -http-server.log.enabled=false +# Enable HTTP request log as it's grepped in integration tests for the heartbeat mechanism +http-server.log.enabled=true +http-server.log.immediate-flush=true diff --git a/tests/development_server.py b/tests/development_server.py index cf71663a..27bd3795 100644 --- a/tests/development_server.py +++ b/tests/development_server.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from pathlib import Path +import docker.errors from testcontainers.core.container import DockerContainer from testcontainers.core.network import Network from testcontainers.core.waiting_utils import wait_for_logs @@ -13,10 +14,36 @@ MINIO_ROOT_USER = "minio-access-key" MINIO_ROOT_PASSWORD = "minio-secret-key" +TRINO_IMAGE_REPO = "trinodb/trino" TRINO_VERSION = os.environ.get("TRINO_VERSION") or "latest" +TRINO_CONTAINER_NAME = "trino" TRINO_HOST = "localhost" +def get_trino_container(port: int): + """Find and return a running trino container. + Returns None if no matching container is found. + """ + client = docker.from_env() + try: + container = client.containers.get(TRINO_CONTAINER_NAME) + except docker.errors.NotFound: + return None + + if not any(tag.startswith(f"{TRINO_IMAGE_REPO}:") for tag in (container.image.tags or [])): + return None + + host_ports = [ + binding["HostPort"] + for bindings in container.ports.values() if bindings + for binding in bindings + ] + if str(port) not in host_ports: + return None + + return container + + def create_bucket(s3_client): bucket_name = "spooling" try: @@ -68,8 +95,8 @@ def start_development_server(port=None, trino_version=TRINO_VERSION): # create spooling bucket create_bucket(localstack.get_client("s3")) - trino = DockerContainer(f"trinodb/trino:{trino_version}") \ - .with_name("trino") \ + trino = DockerContainer(f"{TRINO_IMAGE_REPO}:{trino_version}") \ + .with_name(TRINO_CONTAINER_NAME) \ .with_network(network) \ .with_env("TRINO_CONFIG_DIR", "/etc/trino") \ .with_bind_ports(DEFAULT_PORT, port) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 53f86504..35ef0b30 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -27,6 +27,7 @@ from tzlocal import get_localzone_name # type: ignore import trino +from tests.development_server import get_trino_container from tests.integration.conftest import trino_version from trino import constants from trino.client import InlineSegment @@ -2000,6 +2001,63 @@ def test_spooled_segments_lazy_description(trino_connection): assert len(cur.fetchall()) == 60175 +@pytest.mark.skipif( + trino_version() <= 466, + reason="spooling protocol was introduced in version 466" +) +def test_heartbeat_head_requests_during_spooled_download(run_trino): + """Verify that heartbeat HEAD requests are sent to the coordinator while + downloading spooled segments from external storage.""" + host, port = run_trino + container = get_trino_container(port) + assert container, "Cannot find a running Trino container" + + conn = trino.dbapi.Connection( + host=host, port=port, user="test", source="test", + max_attempts=1, encoding="json", heartbeat_interval=0.1, + ) + + log_path = "/data/trino/var/log/http-request.log" + + # Capture the current size of the HTTP request log + exit_code, output = container.exec_run(["wc", "-l", log_path]) + assert exit_code == 0, f"Cannot read Trino HTTP request log, is the log path `{log_path}` correct?" + logfile_lines = int(output.decode().split()[0]) + + cur = conn.cursor() + cur.execute("""SELECT l.* + FROM tpch.tiny.lineitem l, TABLE(sequence( + start => 1, + stop => 5, + step => 1)) n""") + query_id = cur.query_id + cur.fetchall() + cur.close() + + head_request_found = False + # Sometimes trino needs time to flush the logs so we make few attempts + # to check the log with sleep inbetween. + for attempt in range(10): + if attempt: + t.sleep(1.0) + + _, output = container.exec_run(["tail", "-n", f"+{logfile_lines}", log_path]) + loglines = output.decode().splitlines() + + pattern = f"/v1/statement/executing/{query_id}/" + for line in loglines: + if "HEAD" in line and pattern in line: + head_request_found = True + break + + if head_request_found: + break + + assert head_request_found, ( + f"Expected heartbeat HEAD requests in http-request.log not found. Log tail:\n{''.join(loglines)}" + ) + + def get_cursor(legacy_prepared_statements, run_trino): host, port = run_trino diff --git a/tests/unit/test_client_spooling.py b/tests/unit/test_client_spooling.py new file mode 100644 index 00000000..e739cb2e --- /dev/null +++ b/tests/unit/test_client_spooling.py @@ -0,0 +1,202 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import time +from unittest import mock + +import pytest + +from trino.client import _RequestHeartbeat +from trino.client import ClientSession +from trino.client import DecodableSegment +from trino.client import InlineSegment +from trino.client import SegmentIterator +from trino.client import SpooledSegment +from trino.client import TrinoQuery +from trino.client import TrinoRequest + + +def _mock_trino_request(): + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + req._next_uri = "http://coordinator/v1/statement/q/1" + return req + + +def _head_response(status_code): + return mock.Mock(status_code=status_code, ok=(200 <= status_code < 300)) + + +@pytest.fixture +def ensure_max_failures_3(): + # Some tests assume _RequestHeartbeart.MAX_FAILURES is set to 3 + with mock.patch.object(_RequestHeartbeat, "MAX_FAILURES", 3): + yield + + +def test_heartbeat_sends_head_to_next_uri(): + req = _mock_trino_request() + with mock.patch.object(req, "head", return_value=_head_response(200)) as mock_head: + with _RequestHeartbeat(req, interval=0.01): + time.sleep(0.1) + assert mock_head.call_count >= 2 + mock_head.assert_called_with(req.next_uri) + + +@pytest.mark.parametrize("status_code", (404, 405)) +def test_heartbeat_stops_on_404_405(status_code): + req = _mock_trino_request() + with mock.patch.object(req, "head", return_value=_head_response(status_code)) as mock_head: + with _RequestHeartbeat(req, interval=0.01): + time.sleep(0.1) + # 404/405 means the server does not support heartbeat requests; they should stop after the first one + assert mock_head.call_count == 1 + + +def test_heartbeat_stops_after_max_failures_non_2xx(ensure_max_failures_3): + req = _mock_trino_request() + with mock.patch.object(req, "head", return_value=_head_response(500)) as mock_head: + with _RequestHeartbeat(req, interval=0.01): + time.sleep(0.1) + assert mock_head.call_count == _RequestHeartbeat.MAX_FAILURES + + +def test_heartbeat_stops_after_max_failures_on_exception(ensure_max_failures_3): + req = _mock_trino_request() + with mock.patch.object(req, "head", side_effect=Exception("network error")) as mock_head: + with _RequestHeartbeat(req, interval=0.01): + time.sleep(0.1) + assert mock_head.call_count == _RequestHeartbeat.MAX_FAILURES + + +def test_heartbeat_resets_failure_count_on_success(ensure_max_failures_3): + req = _mock_trino_request() + # Failure counter resets on 200 so the heartbeat keeps running past initial failures + responses = [_head_response(500), _head_response(500)] + [_head_response(200)] * 20 + with mock.patch.object(req, "head", side_effect=responses) as mock_head: + with _RequestHeartbeat(req, interval=0.01): + time.sleep(0.1) + assert mock_head.call_count > _RequestHeartbeat.MAX_FAILURES + + +def test_heartbeat_skips_when_next_uri_is_none(): + req = _mock_trino_request() + req._next_uri = None + with mock.patch.object(req, "head") as mock_head: + with _RequestHeartbeat(req, interval=0.01): + time.sleep(0.1) + mock_head.assert_not_called() + + +def test_heartbeat_stop_is_immediate(): + req = _mock_trino_request() + with mock.patch.object(req, "head", return_value=_head_response(200)): + hb = _RequestHeartbeat(req, interval=30) + start = time.monotonic() + with hb: + pass + elapsed = time.monotonic() - start + assert elapsed < 1.0 + + +def _spooled_iterator(request, heartbeat_interval, rows=None): + """SegmentIterator with one SpooledSegment and a pre-set mock decoder.""" + segment = DecodableSegment("json", None, mock.Mock(spec=SpooledSegment)) + mapper = mock.Mock() + it = SegmentIterator([segment], mapper, request=request, heartbeat_interval=heartbeat_interval) + it._decoder = mock.Mock() + it._decoder.decode.return_value = rows if rows is not None else [[1, 2]] + return it + + +@pytest.mark.parametrize( + "trino_request, interval", + [(None, 1.0), (_mock_trino_request(), None), (_mock_trino_request(), 0.0)] +) +def test_iterator_value_error_when_only_request_or_heartbeat_interval_specified(trino_request, interval): + with pytest.raises(ValueError): + _ = _spooled_iterator(trino_request, interval) + + +def test_heartbeat_starts_during_spooled_segment_download(): + req = _mock_trino_request() + iterator = _spooled_iterator(req, heartbeat_interval=30.0) + with mock.patch("trino.client._RequestHeartbeat") as MockHB: + next(iterator) + MockHB.assert_called_once_with(req, 30.0) + # Make sure MockHB instance is used as a context manager + MockHB.return_value.__enter__.assert_called_once() + MockHB.return_value.__exit__.assert_called_once() + + +def test_no_heartbeat_for_inline_segment(): + segment = DecodableSegment("json", None, mock.Mock(spec=InlineSegment)) + mapper = mock.Mock() + iterator = SegmentIterator([segment], mapper, request=_mock_trino_request(), heartbeat_interval=30.0) + iterator._decoder = mock.Mock() + iterator._decoder.decode.return_value = [[1, 2]] + with mock.patch("trino.client._RequestHeartbeat") as MockHB: + next(iterator) + MockHB.assert_not_called() + + +@pytest.mark.parametrize("interval", (None, 0.0)) +def test_no_heartbeat_when_interval_none_or_zero(interval): + iterator = _spooled_iterator(request=None, heartbeat_interval=interval) + with mock.patch("trino.client._RequestHeartbeat") as MockHB: + next(iterator) + MockHB.assert_not_called() + + +def _spooled_fetch_response(): + """Minimal spooled protocol GET response JSON.""" + resp = mock.Mock() + resp.status_code = 200 + resp.ok = True + resp.headers = {} + resp.text = json.dumps({ + "id": "q1", + "infoUri": "http://coordinator/query.html?q1", + "stats": {"state": "FINISHED"}, + "data": { + "encoding": "json", + "segments": [ + { + "type": "inline", + "metadata": {"uncompressedSize": "10", "segmentSize": "10"}, + "data": "", + } + ], + }, + }) + return resp + + +@pytest.mark.parametrize("heartbeat_interval", (30.0, None)) +def test_fetch_passes_request_and_interval_to_segment_iterator(heartbeat_interval): + session = ClientSession(user="test", encoding="json", heartbeat_interval=heartbeat_interval) + req = TrinoRequest(host="coordinator", port=8080, client_session=session, http_scheme="http") + req._next_uri = "http://coordinator/v1/statement/q1/1" + query = TrinoQuery(req, query="SELECT 1") + query._row_mapper = mock.Mock() + + with mock.patch.object(req, "get", return_value=_spooled_fetch_response()): + with mock.patch("trino.client.SegmentIterator") as MockSI: + MockSI.return_value = iter([]) + query.fetch() + + assert MockSI.call_args.kwargs["request"] is req + assert MockSI.call_args.kwargs["heartbeat_interval"] == heartbeat_interval diff --git a/trino/client.py b/trino/client.py index 97ef11ae..692d10fa 100644 --- a/trino/client.py +++ b/trino/client.py @@ -194,6 +194,7 @@ def __init__( roles: Optional[Union[Dict[str, str], str]] = None, timezone: Optional[str] = None, encoding: Optional[Union[str, List[str]]] = None, + heartbeat_interval: Optional[float] = constants.DEFAULT_HEARTBEAT_INTERVAL, ): self._object_lock = threading.Lock() self._prepared_statements: Dict[str, str] = {} @@ -216,6 +217,7 @@ def __init__( from tzlocal import get_localzone_name self._timezone = get_localzone_name() self._encoding = encoding + self._heartbeat_interval = heartbeat_interval @property def user(self) -> str: @@ -316,6 +318,10 @@ def encoding(self) -> Optional[Union[str, List[str]]]: with self._object_lock: return self._encoding + @property + def heartbeat_interval(self) -> Optional[float]: + return self._heartbeat_interval + @staticmethod def _format_roles(roles: Union[Dict[str, str], str]) -> Dict[str, str]: if isinstance(roles, str): @@ -632,6 +638,7 @@ def max_attempts(self, value: int) -> None: self._get = self._http_session.get self._post = self._http_session.post self._delete = self._http_session.delete + self._head = self._http_session.head return with_retry = _retry_with( @@ -647,6 +654,7 @@ def max_attempts(self, value: int) -> None: self._get = with_retry(self._http_session.get) self._post = with_retry(self._http_session.post) self._delete = with_retry(self._http_session.delete) + self._head = with_retry(self._http_session.head) def get_url(self, path: str) -> str: return "{protocol}://{host}:{port}{path}".format( @@ -690,6 +698,14 @@ def get(self, url: str) -> Response: def delete(self, url: str) -> Response: return self._delete(url, timeout=self._request_timeout, proxies=PROXIES) + def head(self, url: str) -> Response: + return self._head( + url, + headers=self.http_headers, + timeout=self._request_timeout, + proxies=PROXIES, + ) + @staticmethod def _process_error(error, query_id: Optional[str]) -> Union[TrinoExternalError, TrinoQueryError, TrinoUserError]: error_type = error["errorType"] @@ -1003,7 +1019,12 @@ def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]: if self._fetch_mode == "segments": return spooled # Return iterator directly, do NOT materialize with list() - return SegmentIterator(spooled, self._row_mapper) + return SegmentIterator( + spooled, + self._row_mapper, + request=self._request, + heartbeat_interval=self._request._client_session.heartbeat_interval, + ) elif isinstance(status.rows, list): return self._row_mapper.map(rows) else: @@ -1274,14 +1295,77 @@ def __repr__(self): return (f"DecodableSegment(encoding={self._encoding}, metadata={self._metadata}, segment={self._segment})") +class _RequestHeartbeat: + """ + Heartbeat loop for a trino request. Periodically sends HEAD requests to the request's next URI. + This prevents the coordinator from abandoning a query if the client is silent for a longer + period of time, for example when downloading a spooled segment from an external storage. + """ + MAX_FAILURES = 3 + + def __init__(self, request: TrinoRequest, interval: float) -> None: + self._request = request + self._interval = interval + # The event for telling the heartbeat thread to exit + self._stop_event = threading.Event() + + def __enter__(self) -> _RequestHeartbeat: + threading.Thread(target=self._run, daemon=True).start() + return self + + def __exit__(self, *_) -> None: + self._stop_event.set() + + def _run(self) -> None: + """ + Run the heartbeat loop. + + Exit when the self._stop_event is set, the query completed + or if the error count exceeds _MAX_FAILURES. + """ + failures = 0 + + while not self._stop_event.wait(timeout=self._interval): + uri = self._request.next_uri + if uri is None: + return + + try: + response = self._request.head(uri) + if response.status_code in (404, 405): + logger.warning("The server does not support heartbeat calls") + return + if not response.ok: + failures += 1 + else: + failures = 0 + except Exception: + failures += 1 + + if failures >= self.MAX_FAILURES: + logger.warning(f"Stopping the heartbeat after {self.MAX_FAILURES} consecutive errors") + return + + class SegmentIterator: - def __init__(self, segments: Union[DecodableSegment, List[DecodableSegment]], mapper: RowMapper) -> None: + def __init__( + self, + segments: Union[DecodableSegment, List[DecodableSegment]], + mapper: RowMapper, + *, + request: Optional[TrinoRequest] = None, + heartbeat_interval: Optional[float] = None, + ) -> None: self._segments = iter(segments if isinstance(segments, List) else [segments]) self._mapper = mapper self._decoder = None self._rows: Iterator[List[List[Any]]] = iter([]) self._finished = False self._current_segment: Optional[DecodableSegment] = None + if (request is not None) != bool(heartbeat_interval): + raise ValueError("request and heartbeat_interval must be both provided or both omitted") + self._request = request + self._heartbeat_interval = heartbeat_interval def __iter__(self) -> Iterator[List[Any]]: return self @@ -1307,7 +1391,17 @@ def _load_next_segment(self): if self._decoder is None: self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(self._mapper) .create(self._current_segment.encoding)) - self._rows = iter(self._decoder.decode(self._current_segment.segment)) + + if isinstance(self._current_segment.segment, SpooledSegment) and self._request and self._heartbeat_interval: + # Downloading a spooled segment may take some time. In the meantime, send heartbeat + # requests so the coordinator doesn't think we lost interest and close the query. + with _RequestHeartbeat(self._request, self._heartbeat_interval): + rows = self._decoder.decode(self._current_segment.segment) + else: + rows = self._decoder.decode(self._current_segment.segment) + + self._rows = iter(rows) + except StopIteration: self._finished = True diff --git a/trino/constants.py b/trino/constants.py index b136aaaf..ccabb397 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -20,6 +20,7 @@ DEFAULT_AUTH: Optional[Any] = None DEFAULT_MAX_ATTEMPTS = 3 DEFAULT_REQUEST_TIMEOUT: float = 30.0 +DEFAULT_HEARTBEAT_INTERVAL: float = 30.0 MAX_NT_PASSWORD_SIZE: int = 1280 HTTP = "http" diff --git a/trino/dbapi.py b/trino/dbapi.py index 42eeb547..3749d797 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -164,6 +164,7 @@ def __init__( roles=None, timezone=None, encoding: Union[str, List[str]] = _USE_DEFAULT_ENCODING, + heartbeat_interval: Optional[float] = constants.DEFAULT_HEARTBEAT_INTERVAL, ): # Automatically assign http_schema, port based on hostname parsed_host = urlparse(host, allow_fragments=False) @@ -194,6 +195,7 @@ def __init__( roles=roles, timezone=timezone, encoding=encoding, + heartbeat_interval=heartbeat_interval, ) # mypy cannot follow module import if http_session is None: