Skip to content

Commit fad111c

Browse files
authored
Merge pull request #763 from r142f/local-dc-detection
feat: add nearest DC detection with TCP race
1 parent c860928 commit fad111c

9 files changed

Lines changed: 1078 additions & 3 deletions
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# -*- coding: utf-8 -*-
2+
import pytest
3+
from unittest.mock import MagicMock, patch, AsyncMock
4+
from ydb import driver, connection
5+
from ydb.aio import pool, _utilities
6+
7+
8+
class MockEndpointInfo:
9+
def __init__(self, address, port, location):
10+
self.address = address
11+
self.port = port
12+
self.endpoint = f"{address}:{port}"
13+
self.location = location
14+
self.ssl = False
15+
self.node_id = 1
16+
17+
def endpoints_with_options(self):
18+
yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id))
19+
20+
21+
class MockDiscoveryResult:
22+
def __init__(self, self_location, endpoints):
23+
self.self_location = self_location
24+
self.endpoints = endpoints
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_detect_local_dc_overrides_server_location():
29+
"""Test that detected location overrides server's self_location for preferred endpoints."""
30+
# Server reports dc1, but we detect dc2 as nearest
31+
endpoints = [
32+
MockEndpointInfo("dc1-host", 2135, "dc1"),
33+
MockEndpointInfo("dc2-host", 2135, "dc2"),
34+
]
35+
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)
36+
37+
mock_resolver = MagicMock()
38+
mock_resolver.resolve = AsyncMock(return_value=mock_result)
39+
40+
preferred = []
41+
42+
def mock_init(self, endpoint, driver_config, endpoint_options=None):
43+
self.endpoint = endpoint
44+
self.node_id = 1
45+
46+
with patch.object(_utilities, "detect_local_dc", AsyncMock(return_value="dc2")):
47+
with patch("ydb.aio.connection.Connection.__init__", mock_init):
48+
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
49+
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
50+
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
51+
config = driver.DriverConfig(
52+
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
53+
)
54+
discovery = pool.Discovery(
55+
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
56+
)
57+
discovery._resolver = mock_resolver
58+
59+
original_add = discovery._cache.add
60+
discovery._cache.add = lambda conn, pref=False: (
61+
preferred.append(conn.endpoint) if pref else None,
62+
original_add(conn, pref),
63+
)[1]
64+
65+
await discovery.execute_discovery()
66+
67+
assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)"
68+
assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred"
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_detect_local_dc_failure_fallback():
73+
"""Test that detection failure falls back to server's self_location."""
74+
endpoints = [
75+
MockEndpointInfo("dc1-host", 2135, "dc1"),
76+
MockEndpointInfo("dc2-host", 2135, "dc2"),
77+
]
78+
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)
79+
80+
mock_resolver = MagicMock()
81+
mock_resolver.resolve = AsyncMock(return_value=mock_result)
82+
83+
preferred = []
84+
85+
def mock_init(self, endpoint, driver_config, endpoint_options=None):
86+
self.endpoint = endpoint
87+
self.node_id = 1
88+
89+
with patch.object(_utilities, "detect_local_dc", AsyncMock(return_value=None)):
90+
with patch("ydb.aio.connection.Connection.__init__", mock_init):
91+
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
92+
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
93+
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
94+
config = driver.DriverConfig(
95+
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False
96+
)
97+
discovery = pool.Discovery(
98+
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
99+
)
100+
discovery._resolver = mock_resolver
101+
102+
original_add = discovery._cache.add
103+
discovery._cache.add = lambda conn, pref=False: (
104+
preferred.append(conn.endpoint) if pref else None,
105+
original_add(conn, pref),
106+
)[1]
107+
108+
await discovery.execute_discovery()
109+
110+
assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)"
111+
112+
113+
@pytest.mark.asyncio
114+
async def test_detect_local_dc_skipped_when_use_all_nodes_true():
115+
"""Test that detect_local_dc is NOT called when use_all_nodes=True."""
116+
endpoints = [
117+
MockEndpointInfo("dc1-host", 2135, "dc1"),
118+
MockEndpointInfo("dc2-host", 2135, "dc2"),
119+
]
120+
mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints)
121+
122+
mock_resolver = MagicMock()
123+
mock_resolver.resolve = AsyncMock(return_value=mock_result)
124+
125+
def mock_init(self, endpoint, driver_config, endpoint_options=None):
126+
self.endpoint = endpoint
127+
self.node_id = 1
128+
129+
with patch.object(_utilities, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock:
130+
with patch("ydb.aio.connection.Connection.__init__", mock_init):
131+
with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()):
132+
with patch("ydb.aio.connection.Connection.close", AsyncMock()):
133+
with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None):
134+
config = driver.DriverConfig(
135+
endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True
136+
)
137+
discovery = pool.Discovery(
138+
store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config
139+
)
140+
discovery._resolver = mock_resolver
141+
await discovery.execute_discovery()
142+
143+
assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True"

tests/aio/test_nearest_dc.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import asyncio
2+
import pytest
3+
from ydb.aio import _utilities
4+
5+
6+
class MockEndpoint:
7+
def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()):
8+
self.address = address
9+
self.port = port
10+
self.endpoint = f"{address}:{port}"
11+
self.location = location
12+
self.ipv4_addrs = ipv4_addrs
13+
self.ipv6_addrs = ipv6_addrs
14+
15+
16+
class MockWriter:
17+
def __init__(self):
18+
self.closed = False
19+
20+
def close(self):
21+
self.closed = True
22+
23+
async def wait_closed(self):
24+
await asyncio.sleep(0)
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_check_fastest_endpoint_empty():
29+
assert await _utilities._check_fastest_endpoint([]) is None
30+
31+
32+
@pytest.mark.asyncio
33+
async def test_check_fastest_endpoint_all_fail(monkeypatch):
34+
async def fake_open_connection(host, port):
35+
raise OSError("connect failed")
36+
37+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection)
38+
39+
endpoints = [
40+
MockEndpoint("a", 1, "dc1"),
41+
MockEndpoint("b", 1, "dc2"),
42+
]
43+
assert await _utilities._check_fastest_endpoint(endpoints, timeout=0.05) is None
44+
45+
46+
@pytest.mark.asyncio
47+
async def test_check_fastest_endpoint_fastest_wins(monkeypatch):
48+
async def fake_open_connection(host, port):
49+
if host == "slow":
50+
await asyncio.sleep(0.05)
51+
return None, MockWriter()
52+
53+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection)
54+
55+
endpoints = [
56+
MockEndpoint("slow", 1, "dc_slow"),
57+
MockEndpoint("fast", 1, "dc_fast"),
58+
]
59+
winner = await _utilities._check_fastest_endpoint(endpoints, timeout=0.2)
60+
assert winner is not None
61+
assert winner.location == "dc_fast"
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch):
66+
async def fake_open_connection(host, port):
67+
await asyncio.sleep(0.2)
68+
return None, MockWriter()
69+
70+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection)
71+
72+
endpoints = [
73+
MockEndpoint("hang1", 1, "dc1"),
74+
MockEndpoint("hang2", 1, "dc2"),
75+
]
76+
77+
winner = await _utilities._check_fastest_endpoint(endpoints, timeout=0.05)
78+
79+
assert winner is None
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_detect_local_dc_empty_endpoints():
84+
with pytest.raises(ValueError, match="Empty endpoints"):
85+
await _utilities.detect_local_dc([])
86+
87+
88+
@pytest.mark.asyncio
89+
async def test_detect_local_dc_single_location_returns_immediately(monkeypatch):
90+
async def fail_if_called(*args, **kwargs):
91+
raise AssertionError("open_connection should not be called for single location")
92+
93+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fail_if_called)
94+
95+
endpoints = [
96+
MockEndpoint("h1", 1, "dc1"),
97+
MockEndpoint("h2", 1, "dc1"),
98+
]
99+
assert await _utilities.detect_local_dc(endpoints) == "dc1"
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch):
104+
async def fake_open_connection(host, port):
105+
raise OSError("connect failed")
106+
107+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection)
108+
109+
endpoints = [
110+
MockEndpoint("bad1", 9999, "dc1"),
111+
MockEndpoint("bad2", 9999, "dc2"),
112+
]
113+
assert await _utilities.detect_local_dc(endpoints, timeout=0.05) is None
114+
115+
116+
@pytest.mark.asyncio
117+
async def test_detect_local_dc_returns_location_of_fastest(monkeypatch):
118+
async def fake_open_connection(host, port):
119+
if host == "dc1_host":
120+
await asyncio.sleep(0.05)
121+
return None, MockWriter()
122+
123+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection)
124+
125+
endpoints = [
126+
MockEndpoint("dc1_host", 1, "dc1"),
127+
MockEndpoint("dc2_host", 1, "dc2"),
128+
]
129+
assert await _utilities.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2"
130+
131+
132+
@pytest.mark.asyncio
133+
async def test_detect_local_dc_respects_max_per_location(monkeypatch):
134+
calls = []
135+
136+
async def fake_open_connection(host, port):
137+
calls.append((host, port))
138+
raise OSError("connect failed")
139+
140+
monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection)
141+
142+
endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [
143+
MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5)
144+
]
145+
await _utilities.detect_local_dc(endpoints, max_per_location=2, timeout=0.2)
146+
147+
assert len(calls) == 4
148+
149+
150+
@pytest.mark.asyncio
151+
async def test_detect_local_dc_validates_max_per_location():
152+
endpoints = [MockEndpoint("h1", 1, "dc1")]
153+
with pytest.raises(ValueError, match="max_per_location must be >= 1"):
154+
await _utilities.detect_local_dc(endpoints, max_per_location=0)
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_detect_local_dc_validates_timeout():
159+
endpoints = [MockEndpoint("h1", 1, "dc1")]
160+
with pytest.raises(ValueError, match="timeout must be > 0"):
161+
await _utilities.detect_local_dc(endpoints, timeout=0)

0 commit comments

Comments
 (0)