Skip to content

Commit b23558b

Browse files
committed
(improvement) unit tests for benchmarking query planning.
Not a very scientific one, but reasonable to get some measurements in terms of how different optimizations work. Example run (on scylladb#650 branch): ykaul@ykaul:~/github/python-driver$ pytest -s tests/unit/test_policy_performance.py /usr/lib/python3.14/site-packages/pytest_asyncio/plugin.py:211: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ============================================================================================================ test session starts ============================================================================================================= platform linux -- Python 3.14.2, pytest-8.3.5, pluggy-1.6.0 rootdir: /home/ykaul/github/python-driver configfile: pyproject.toml plugins: asyncio-1.1.0, anyio-4.12.1 asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function collected 4 items tests/unit/test_policy_performance.py Pinned to CPU 0 .... === Performance Benchmarks === Policy | Ops | Time (s) | Kops/s ---------------------------------------------------------------------- DCAware | 100000 | 0.2328 | 429 RackAware | 100000 | 0.3637 | 274 TokenAware(DCAware) | 100000 | 1.5884 | 62 TokenAware(RackAware) | 100000 | 1.6816 | 59 ---------------------------------------------------------------------- Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 711a7eb commit b23558b

1 file changed

Lines changed: 214 additions & 0 deletions

File tree

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import unittest
2+
import time
3+
import uuid
4+
import struct
5+
import os
6+
import statistics
7+
from unittest.mock import Mock
8+
9+
from cassandra.policies import (
10+
DCAwareRoundRobinPolicy,
11+
RackAwareRoundRobinPolicy,
12+
TokenAwarePolicy,
13+
DefaultLoadBalancingPolicy,
14+
HostFilterPolicy
15+
)
16+
from cassandra.pool import Host
17+
from cassandra.cluster import SimpleConvictionPolicy
18+
19+
# Mock for Connection/EndPoint since Host expects it
20+
class MockEndPoint(object):
21+
__slots__ = ('address',)
22+
23+
def __init__(self, address):
24+
self.address = address
25+
def __str__(self):
26+
return self.address
27+
28+
class MockStatement(object):
29+
__slots__ = ('routing_key', 'keyspace', 'table')
30+
31+
def __init__(self, routing_key, keyspace="ks", table="tbl"):
32+
self.routing_key = routing_key
33+
self.keyspace = keyspace
34+
self.table = table
35+
36+
def is_lwt(self):
37+
return False
38+
39+
class MockTokenMap(object):
40+
__slots__ = ('token_class', 'get_replicas_func')
41+
def __init__(self, get_replicas_func):
42+
self.token_class = Mock()
43+
self.token_class.from_key = lambda k: k
44+
self.get_replicas_func = get_replicas_func
45+
46+
def get_replicas(self, keyspace, token):
47+
return self.get_replicas_func(keyspace, token)
48+
49+
class MockTablets(object):
50+
__slots__ = ()
51+
def get_tablet_for_key(self, keyspace, table, key):
52+
return None
53+
54+
class MockMetadata(object):
55+
__slots__ = ('_tablets', 'token_map', 'get_replicas_func', 'hosts_by_address')
56+
def __init__(self, get_replicas_func, hosts_by_address):
57+
self._tablets = MockTablets()
58+
self.token_map = MockTokenMap(get_replicas_func)
59+
self.get_replicas_func = get_replicas_func
60+
self.hosts_by_address = hosts_by_address
61+
62+
def can_support_partitioner(self):
63+
return True
64+
65+
def get_replicas(self, keyspace, key):
66+
return self.get_replicas_func(keyspace, key)
67+
68+
def get_host(self, addr):
69+
return self.hosts_by_address.get(addr)
70+
71+
class MockCluster(object):
72+
__slots__ = ('metadata',)
73+
def __init__(self, metadata):
74+
self.metadata = metadata
75+
76+
class TestPolicyPerformance(unittest.TestCase):
77+
@classmethod
78+
def setUpClass(cls):
79+
if hasattr(os, 'sched_setaffinity'):
80+
try:
81+
# Pin to the first available CPU
82+
cpu = list(os.sched_getaffinity(0))[0]
83+
os.sched_setaffinity(0, {cpu})
84+
print(f"Pinned to CPU {cpu}")
85+
except Exception as e:
86+
print(f"Could not pin CPU: {e}")
87+
88+
# 1. Topology: 5 DCs, 3 Racks/DC, 3 Nodes/Rack = 45 Nodes
89+
cls.hosts = []
90+
cls.hosts_map = {} # host_id -> Host
91+
cls.replicas_map = {} # routing_key -> list of replica hosts
92+
93+
# Deterministic generation
94+
dcs = ['dc{}'.format(i) for i in range(5)]
95+
racks = ['rack{}'.format(i) for i in range(3)]
96+
nodes_per_rack = 3
97+
98+
ip_counter = 0
99+
subnet_counter = 0
100+
for dc in dcs:
101+
for rack in racks:
102+
subnet_counter += 1
103+
for node_idx in range(nodes_per_rack):
104+
ip_counter += 1
105+
address = "127.0.{}.{}".format(subnet_counter, node_idx + 1)
106+
h_id = uuid.UUID(int=ip_counter)
107+
h = Host(MockEndPoint(address), SimpleConvictionPolicy, host_id=h_id)
108+
h.set_location_info(dc, rack)
109+
cls.hosts.append(h)
110+
cls.hosts_map[h_id] = h
111+
112+
# 2. Queries: 100,000 deterministic queries
113+
cls.query_count = 100000
114+
cls.queries = []
115+
cls.results = []
116+
# We'll use simple packed integers as routing keys
117+
for i in range(cls.query_count):
118+
key = struct.pack('>I', i)
119+
cls.queries.append(MockStatement(routing_key=key))
120+
121+
# Pre-calculate replicas for TokenAware:
122+
# Deterministically pick 3 replicas based on the key index
123+
# This simulates the metadata.get_replicas behavior
124+
# We pick index i, i+1, i+2 mod 45
125+
replicas = []
126+
for r in range(3):
127+
idx = (i + r) % len(cls.hosts)
128+
replicas.append(cls.hosts[idx])
129+
cls.replicas_map[key] = replicas
130+
131+
def _get_replicas_side_effect(self, keyspace, key):
132+
return self.replicas_map.get(key, [])
133+
134+
def _setup_cluster_mock(self):
135+
hosts_by_address = {}
136+
for host in self.hosts:
137+
addr = getattr(host, 'address', None)
138+
if addr is None and getattr(host, 'endpoint', None) is not None:
139+
addr = getattr(host.endpoint, 'address', None)
140+
if addr is not None:
141+
hosts_by_address[addr] = host
142+
metadata = MockMetadata(self._get_replicas_side_effect, hosts_by_address)
143+
return MockCluster(metadata)
144+
145+
def _run_benchmark(self, name, policy):
146+
# Setup
147+
cluster = self._setup_cluster_mock()
148+
policy.populate(cluster, self.hosts)
149+
150+
# Warmup
151+
for _ in range(100):
152+
list(policy.make_query_plan(working_keyspace="ks", query=self.queries[0]))
153+
154+
# Run multiple iterations to reduce noise
155+
iterations = 5
156+
timings = []
157+
158+
for _ in range(iterations):
159+
start_time = time.perf_counter()
160+
for q in self.queries:
161+
# We consume the iterator to ensure full plan generation cost is paid
162+
for _ in policy.make_query_plan(working_keyspace="ks", query=q):
163+
pass
164+
end_time = time.perf_counter()
165+
timings.append(end_time - start_time)
166+
167+
# Use median to filter outliers
168+
duration = statistics.median(timings)
169+
170+
count = len(self.queries)
171+
ops_per_sec = count / duration
172+
kops = int(ops_per_sec / 1000)
173+
174+
self.results.append((name, count, duration, kops))
175+
return ops_per_sec
176+
177+
@classmethod
178+
def tearDownClass(cls):
179+
print("\n\n=== Performance Benchmarks ===")
180+
print(f"{'Policy':<30} | {'Ops':<10} | {'Time (s)':<10} | {'Kops/s':<10}")
181+
print("-" * 70)
182+
for name, count, duration, kops in cls.results:
183+
print(f"{name:<30} | {count:<10} | {duration:<10.4f} | {kops:<10}")
184+
print("-" * 70)
185+
186+
def test_dc_aware(self):
187+
# Local DC = dc0, 1 remote host per DC
188+
policy = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
189+
self._run_benchmark("DCAware", policy)
190+
191+
def test_rack_aware(self):
192+
# Local DC = dc0, Local Rack = rack0, 1 remote host per DC
193+
policy = RackAwareRoundRobinPolicy(local_dc='dc0', local_rack='rack0', used_hosts_per_remote_dc=1)
194+
self._run_benchmark("RackAware", policy)
195+
196+
def test_token_aware_wrapping_dc_aware(self):
197+
child = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
198+
policy = TokenAwarePolicy(child, shuffle_replicas=False) # False for strict determinism in test if needed
199+
self._run_benchmark("TokenAware(DCAware)", policy)
200+
201+
def test_token_aware_wrapping_rack_aware(self):
202+
child = RackAwareRoundRobinPolicy(local_dc='dc0', local_rack='rack0', used_hosts_per_remote_dc=1)
203+
policy = TokenAwarePolicy(child, shuffle_replicas=False)
204+
self._run_benchmark("TokenAware(RackAware)", policy)
205+
206+
def test_default_wrapping_dc_aware(self):
207+
child = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
208+
policy = DefaultLoadBalancingPolicy(child)
209+
self._run_benchmark("Default(DCAware)", policy)
210+
211+
def test_host_filter_wrapping_dc_aware(self):
212+
child = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
213+
policy = HostFilterPolicy(child_policy=child, predicate=lambda host: host.rack != 'rack2')
214+
self._run_benchmark("HostFilter(DCAware)", policy)

0 commit comments

Comments
 (0)