Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions etc/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ protocol.spooling.retrieval-mode=coordinator_proxy
# Enable dynamic catalog management
catalog.management=dynamic

# Disable http request log
http-server.log.enabled=false
# Enable HTTP request log as it's grepped in integration tests for the heartbeat mechanism
http-server.log.enabled=true
http-server.log.immediate-flush=true
31 changes: 29 additions & 2 deletions tests/development_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import contextmanager
from pathlib import Path

import docker.errors
from testcontainers.core.container import DockerContainer
from testcontainers.core.network import Network
from testcontainers.core.waiting_utils import wait_for_logs
Expand All @@ -13,10 +14,36 @@
MINIO_ROOT_USER = "minio-access-key"
MINIO_ROOT_PASSWORD = "minio-secret-key"

TRINO_IMAGE_REPO = "trinodb/trino"
TRINO_VERSION = os.environ.get("TRINO_VERSION") or "latest"
TRINO_CONTAINER_NAME = "trino"
TRINO_HOST = "localhost"


def get_trino_container(port: int):
"""Find and return a running trino container.
Returns None if no matching container is found.
"""
client = docker.from_env()
try:
container = client.containers.get(TRINO_CONTAINER_NAME)
except docker.errors.NotFound:
return None

if not any(tag.startswith(f"{TRINO_IMAGE_REPO}:") for tag in (container.image.tags or [])):
return None

host_ports = [
binding["HostPort"]
for bindings in container.ports.values() if bindings
for binding in bindings
]
if str(port) not in host_ports:
return None

return container


def create_bucket(s3_client):
bucket_name = "spooling"
try:
Expand Down Expand Up @@ -68,8 +95,8 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
# create spooling bucket
create_bucket(localstack.get_client("s3"))

trino = DockerContainer(f"trinodb/trino:{trino_version}") \
.with_name("trino") \
trino = DockerContainer(f"{TRINO_IMAGE_REPO}:{trino_version}") \
.with_name(TRINO_CONTAINER_NAME) \
.with_network(network) \
.with_env("TRINO_CONFIG_DIR", "/etc/trino") \
.with_bind_ports(DEFAULT_PORT, port)
Expand Down
58 changes: 58 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tzlocal import get_localzone_name # type: ignore

import trino
from tests.development_server import get_trino_container
from tests.integration.conftest import trino_version
from trino import constants
from trino.client import InlineSegment
Expand Down Expand Up @@ -2000,6 +2001,63 @@ def test_spooled_segments_lazy_description(trino_connection):
assert len(cur.fetchall()) == 60175


@pytest.mark.skipif(
trino_version() <= 466,
reason="spooling protocol was introduced in version 466"
)
def test_heartbeat_head_requests_during_spooled_download(run_trino):
"""Verify that heartbeat HEAD requests are sent to the coordinator while
downloading spooled segments from external storage."""
host, port = run_trino
container = get_trino_container(port)
assert container, "Cannot find a running Trino container"

conn = trino.dbapi.Connection(
host=host, port=port, user="test", source="test",
max_attempts=1, encoding="json", heartbeat_interval=0.1,
)

log_path = "/data/trino/var/log/http-request.log"

# Capture the current size of the HTTP request log
exit_code, output = container.exec_run(["wc", "-l", log_path])
assert exit_code == 0, f"Cannot read Trino HTTP request log, is the log path `{log_path}` correct?"
logfile_lines = int(output.decode().split()[0])

cur = conn.cursor()
cur.execute("""SELECT l.*
FROM tpch.tiny.lineitem l, TABLE(sequence(
start => 1,
stop => 5,
step => 1)) n""")
query_id = cur.query_id
cur.fetchall()
cur.close()

head_request_found = False
# Sometimes trino needs time to flush the logs so we make few attempts
# to check the log with sleep inbetween.
for attempt in range(10):
if attempt:
t.sleep(1.0)

_, output = container.exec_run(["tail", "-n", f"+{logfile_lines}", log_path])
loglines = output.decode().splitlines()

pattern = f"/v1/statement/executing/{query_id}/"
for line in loglines:
if "HEAD" in line and pattern in line:
head_request_found = True
break

if head_request_found:
break

assert head_request_found, (
f"Expected heartbeat HEAD requests in http-request.log not found. Log tail:\n{''.join(loglines)}"
)


def get_cursor(legacy_prepared_statements, run_trino):
host, port = run_trino

Expand Down
202 changes: 202 additions & 0 deletions tests/unit/test_client_spooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from unittest import mock

import pytest

from trino.client import _RequestHeartbeat
from trino.client import ClientSession
from trino.client import DecodableSegment
from trino.client import InlineSegment
from trino.client import SegmentIterator
from trino.client import SpooledSegment
from trino.client import TrinoQuery
from trino.client import TrinoRequest


def _mock_trino_request():
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
req._next_uri = "http://coordinator/v1/statement/q/1"
return req


def _head_response(status_code):
return mock.Mock(status_code=status_code, ok=(200 <= status_code < 300))


@pytest.fixture
def ensure_max_failures_3():
# Some tests assume _RequestHeartbeart.MAX_FAILURES is set to 3
with mock.patch.object(_RequestHeartbeat, "MAX_FAILURES", 3):
yield


def test_heartbeat_sends_head_to_next_uri():
req = _mock_trino_request()
with mock.patch.object(req, "head", return_value=_head_response(200)) as mock_head:
with _RequestHeartbeat(req, interval=0.01):
time.sleep(0.1)
assert mock_head.call_count >= 2
mock_head.assert_called_with(req.next_uri)


@pytest.mark.parametrize("status_code", (404, 405))
def test_heartbeat_stops_on_404_405(status_code):
req = _mock_trino_request()
with mock.patch.object(req, "head", return_value=_head_response(status_code)) as mock_head:
with _RequestHeartbeat(req, interval=0.01):
time.sleep(0.1)
# 404/405 means the server does not support heartbeat requests; they should stop after the first one
assert mock_head.call_count == 1


def test_heartbeat_stops_after_max_failures_non_2xx(ensure_max_failures_3):
req = _mock_trino_request()
with mock.patch.object(req, "head", return_value=_head_response(500)) as mock_head:
with _RequestHeartbeat(req, interval=0.01):
time.sleep(0.1)
assert mock_head.call_count == _RequestHeartbeat.MAX_FAILURES


def test_heartbeat_stops_after_max_failures_on_exception(ensure_max_failures_3):
req = _mock_trino_request()
with mock.patch.object(req, "head", side_effect=Exception("network error")) as mock_head:
with _RequestHeartbeat(req, interval=0.01):
time.sleep(0.1)
assert mock_head.call_count == _RequestHeartbeat.MAX_FAILURES


def test_heartbeat_resets_failure_count_on_success(ensure_max_failures_3):
req = _mock_trino_request()
# Failure counter resets on 200 so the heartbeat keeps running past initial failures
responses = [_head_response(500), _head_response(500)] + [_head_response(200)] * 20
with mock.patch.object(req, "head", side_effect=responses) as mock_head:
with _RequestHeartbeat(req, interval=0.01):
time.sleep(0.1)
assert mock_head.call_count > _RequestHeartbeat.MAX_FAILURES


def test_heartbeat_skips_when_next_uri_is_none():
req = _mock_trino_request()
req._next_uri = None
with mock.patch.object(req, "head") as mock_head:
with _RequestHeartbeat(req, interval=0.01):
time.sleep(0.1)
mock_head.assert_not_called()


def test_heartbeat_stop_is_immediate():
req = _mock_trino_request()
with mock.patch.object(req, "head", return_value=_head_response(200)):
hb = _RequestHeartbeat(req, interval=30)
start = time.monotonic()
with hb:
pass
elapsed = time.monotonic() - start
assert elapsed < 1.0


def _spooled_iterator(request, heartbeat_interval, rows=None):
"""SegmentIterator with one SpooledSegment and a pre-set mock decoder."""
segment = DecodableSegment("json", None, mock.Mock(spec=SpooledSegment))
mapper = mock.Mock()
it = SegmentIterator([segment], mapper, request=request, heartbeat_interval=heartbeat_interval)
it._decoder = mock.Mock()
it._decoder.decode.return_value = rows if rows is not None else [[1, 2]]
return it


@pytest.mark.parametrize(
"trino_request, interval",
[(None, 1.0), (_mock_trino_request(), None), (_mock_trino_request(), 0.0)]
)
def test_iterator_value_error_when_only_request_or_heartbeat_interval_specified(trino_request, interval):
with pytest.raises(ValueError):
_ = _spooled_iterator(trino_request, interval)


def test_heartbeat_starts_during_spooled_segment_download():
req = _mock_trino_request()
iterator = _spooled_iterator(req, heartbeat_interval=30.0)
with mock.patch("trino.client._RequestHeartbeat") as MockHB:
next(iterator)
MockHB.assert_called_once_with(req, 30.0)
# Make sure MockHB instance is used as a context manager
MockHB.return_value.__enter__.assert_called_once()
MockHB.return_value.__exit__.assert_called_once()


def test_no_heartbeat_for_inline_segment():
segment = DecodableSegment("json", None, mock.Mock(spec=InlineSegment))
mapper = mock.Mock()
iterator = SegmentIterator([segment], mapper, request=_mock_trino_request(), heartbeat_interval=30.0)
iterator._decoder = mock.Mock()
iterator._decoder.decode.return_value = [[1, 2]]
with mock.patch("trino.client._RequestHeartbeat") as MockHB:
next(iterator)
MockHB.assert_not_called()


@pytest.mark.parametrize("interval", (None, 0.0))
def test_no_heartbeat_when_interval_none_or_zero(interval):
iterator = _spooled_iterator(request=None, heartbeat_interval=interval)
with mock.patch("trino.client._RequestHeartbeat") as MockHB:
next(iterator)
MockHB.assert_not_called()


def _spooled_fetch_response():
"""Minimal spooled protocol GET response JSON."""
resp = mock.Mock()
resp.status_code = 200
resp.ok = True
resp.headers = {}
resp.text = json.dumps({
"id": "q1",
"infoUri": "http://coordinator/query.html?q1",
"stats": {"state": "FINISHED"},
"data": {
"encoding": "json",
"segments": [
{
"type": "inline",
"metadata": {"uncompressedSize": "10", "segmentSize": "10"},
"data": "",
}
],
},
})
return resp


@pytest.mark.parametrize("heartbeat_interval", (30.0, None))
def test_fetch_passes_request_and_interval_to_segment_iterator(heartbeat_interval):
session = ClientSession(user="test", encoding="json", heartbeat_interval=heartbeat_interval)
req = TrinoRequest(host="coordinator", port=8080, client_session=session, http_scheme="http")
req._next_uri = "http://coordinator/v1/statement/q1/1"
query = TrinoQuery(req, query="SELECT 1")
query._row_mapper = mock.Mock()

with mock.patch.object(req, "get", return_value=_spooled_fetch_response()):
with mock.patch("trino.client.SegmentIterator") as MockSI:
MockSI.return_value = iter([])
query.fetch()

assert MockSI.call_args.kwargs["request"] is req
assert MockSI.call_args.kwargs["heartbeat_interval"] == heartbeat_interval
Loading
Loading