From 4ff5f0b80f7481ec3a3f6a2ea46e67f0af7ad808 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sat, 31 Jan 2026 15:56:45 -0400 Subject: [PATCH 1/2] Remove scales dependency with self-contained metrics implementation The scales library had its last release in 2015 and was only tested on Python 2.7/3.3. This change replaces it with a self-contained metrics implementation that provides the same functionality: - IntStat: Thread-safe integer counter - PmfStat: Percentile/distribution statistics with reservoir sampling - Stat: Gauge statistics with callable evaluation - StatsCollection: Named collections of statistics - Global registry for metrics access via getStats() The new implementation maintains API compatibility with the existing metrics interface used by the driver. Fixes #665 --- benchmarks/base.py | 4 +- cassandra/metrics.py | 358 ++++++++++++++++++--- docs/installation.rst | 9 - docs/pyproject.toml | 1 - examples/request_init_listener.py | 10 +- pyproject.toml | 1 - tests/integration/standard/test_metrics.py | 20 +- tests/unit/test_metrics.py | 304 +++++++++++++++++ 8 files changed, 639 insertions(+), 68 deletions(-) create mode 100644 tests/unit/test_metrics.py diff --git a/benchmarks/base.py b/benchmarks/base.py index 2000b4069f..d9cd004474 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -21,7 +21,7 @@ from optparse import OptionParser import uuid -from greplin import scales +from cassandra.metrics import getStats dirname = os.path.dirname(os.path.abspath(__file__)) sys.path.append(dirname) @@ -192,7 +192,7 @@ def benchmark(thread_class): log.info("Total time: %0.2fs" % total) log.info("Average throughput: %0.2f/sec" % (options.num_ops / total)) if options.enable_metrics: - stats = scales.getStats()['cassandra'] + stats = getStats()['cassandra'] log.info("Connection errors: %d", stats['connection_errors']) log.info("Write timeouts: %d", stats['write_timeouts']) log.info("Read timeouts: %d", stats['read_timeouts']) diff --git a/cassandra/metrics.py b/cassandra/metrics.py index abfc863b55..6a1af793ec 100644 --- a/cassandra/metrics.py +++ b/cassandra/metrics.py @@ -12,19 +12,295 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Driver metrics collection module. + +This module provides metrics collection functionality without external dependencies. +It was originally based on the `scales` library but now uses a self-contained +implementation. +""" + from itertools import chain import logging - -try: - from greplin import scales -except ImportError: - raise ImportError( - "The scales library is required for metrics support: " - "https://pypi.org/project/scales/") +import math +import random +import threading log = logging.getLogger(__name__) +# Global stats registry +_stats_registry = {} +_registry_lock = threading.Lock() + + +def getStats(): + """ + Returns a copy of all registered stats. + """ + with _registry_lock: + return {name: stats._get_stats_dict() for name, stats in _stats_registry.items()} + + +class IntStat: + """ + A thread-safe integer counter statistic. 
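+
+    Usage sketch (illustrative)::
+
+        errors = IntStat('errors')
+        errors += 1      # thread-safe increment via __iadd__
+        int(errors)      # -> 1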
+ """ + __slots__ = ('name', '_value', '_lock') + + def __init__(self, name): + self.name = name + self._value = 0 + self._lock = threading.Lock() + + def __iadd__(self, other): + with self._lock: + self._value += other + return self + + def __int__(self): + return self._value + + def __repr__(self): + return f"IntStat({self.name}={self._value})" + + @property + def value(self): + return self._value + + +class Stat: + """ + A gauge statistic that evaluates a callable on access. + """ + __slots__ = ('name', '_func') + + def __init__(self, name, func): + self.name = name + self._func = func + + @property + def value(self): + return self._func() + + def __repr__(self): + return f"Stat({self.name}={self.value})" + + +class PmfStat: + """ + A probability mass function statistic that tracks timing/size distributions. + + Computes count, min, max, mean, stddev, median and various percentiles. + Uses reservoir sampling to limit memory usage for large sample counts. + """ + __slots__ = ('name', '_values', '_lock', '_count', '_min', '_max', '_sum', '_sum_sq') + + # Maximum number of values to retain for percentile calculations + _MAX_SAMPLES = 10000 + + def __init__(self, name): + self.name = name + self._values = [] + self._lock = threading.Lock() + self._count = 0 + self._min = float('inf') + self._max = float('-inf') + self._sum = 0.0 + self._sum_sq = 0.0 + + def addValue(self, value): + """Record a new value.""" + with self._lock: + self._count += 1 + self._sum += value + self._sum_sq += value * value + + if value < self._min: + self._min = value + if value > self._max: + self._max = value + + # Reservoir sampling for percentiles + if len(self._values) < self._MAX_SAMPLES: + self._values.append(value) + else: + # Replace random element with decreasing probability + idx = random.randint(0, self._count - 1) + if idx < self._MAX_SAMPLES: + self._values[idx] = value + + def _percentile(self, sorted_values, p): + """Calculate the p-th percentile from sorted values.""" + if not sorted_values: + return 0.0 + k = (len(sorted_values) - 1) * p / 100.0 + f = math.floor(k) + c = math.ceil(k) + if f == c: + return sorted_values[int(k)] + return sorted_values[int(f)] * (c - k) + sorted_values[int(c)] * (k - f) + + def _get_stats(self): + """Calculate all statistics.""" + with self._lock: + count = self._count + if count == 0: + return { + 'count': 0, + 'min': 0.0, + 'max': 0.0, + 'mean': 0.0, + 'stddev': 0.0, + 'median': 0.0, + '75percentile': 0.0, + '95percentile': 0.0, + '98percentile': 0.0, + '99percentile': 0.0, + '999percentile': 0.0, + } + + mean = self._sum / count + + # Calculate stddev using Welford's algorithm values + variance = (self._sum_sq / count) - (mean * mean) + stddev = math.sqrt(max(0, variance)) # max to handle floating point errors + + sorted_values = sorted(self._values) + + return { + 'count': count, + 'min': self._min, + 'max': self._max, + 'mean': mean, + 'stddev': stddev, + 'median': self._percentile(sorted_values, 50), + '75percentile': self._percentile(sorted_values, 75), + '95percentile': self._percentile(sorted_values, 95), + '98percentile': self._percentile(sorted_values, 98), + '99percentile': self._percentile(sorted_values, 99), + '999percentile': self._percentile(sorted_values, 99.9), + } + + def __getitem__(self, key): + return self._get_stats()[key] + + def __iter__(self): + return iter(self._get_stats()) + + def keys(self): + return self._get_stats().keys() + + def items(self): + return self._get_stats().items() + + def values(self): + return self._get_stats().values() + 
+    def __repr__(self):
+        return f"PmfStat({self.name}, count={self._count})"
+
+
+class StatsCollection:
+    """
+    A named collection of statistics.
+    """
+    __slots__ = ('_name', '_stats', '_int_stats', '_pmf_stats', '_gauge_stats')
+
+    def __init__(self, name, *stats):
+        self._name = name
+        self._stats = {}
+        self._int_stats = {}
+        self._pmf_stats = {}
+        self._gauge_stats = {}
+
+        for stat in stats:
+            self._stats[stat.name] = stat
+            if isinstance(stat, IntStat):
+                self._int_stats[stat.name] = stat
+            elif isinstance(stat, PmfStat):
+                self._pmf_stats[stat.name] = stat
+            elif isinstance(stat, Stat):
+                self._gauge_stats[stat.name] = stat
+
+    def __getattr__(self, name):
+        if name.startswith('_'):
+            raise AttributeError(name)
+        try:
+            stats = object.__getattribute__(self, '_stats')
+            if name in stats:
+                return stats[name]
+        except AttributeError:
+            pass
+        raise AttributeError(f"No stat named '{name}'")
+
+    def __setattr__(self, name, value):
+        if name.startswith('_'):
+            object.__setattr__(self, name, value)
+            return
+        # Allow rebinding stats (e.g., for augmented assignment like stats.errors += 1)
+        try:
+            stats = object.__getattribute__(self, '_stats')
+            if name in stats:
+                # For augmented assignment, value is the same IntStat/PmfStat
+                # object returned by __iadd__, so silently accept the rebind
+                # without replacing the stat
+                return
+        except AttributeError:
+            pass
+        raise AttributeError(f"Cannot set attribute '{name}' on StatsCollection")
+
+    def _get_stats_dict(self):
+        """Return dictionary representation of all stats."""
+        result = {}
+        for name, stat in self._int_stats.items():
+            result[name] = stat.value
+        for name, stat in self._pmf_stats.items():
+            result[name] = stat._get_stats()
+        for name, stat in self._gauge_stats.items():
+            result[name] = stat.value
+        return result
+
+
+def collection(name, *stats):
+    """
+    Create a named collection of statistics and register it globally.
+    """
+    coll = StatsCollection(name, *stats)
+    with _registry_lock:
+        _stats_registry[name] = coll
+    return coll
+
+
+def init(obj, path):
+    """
+    Initialize class-level stats on an instance and register in the global registry.
+
+    This allows class-level PmfStat/IntStat attributes to be shadowed by
+    per-instance copies, mirroring the behavior of ``scales.init()``.
+    """
+    # Get class-level stats and create instance copies
+    cls = obj.__class__
+    instance_stats = {}
+
+    for attr_name in dir(cls):
+        attr = getattr(cls, attr_name, None)
+        if isinstance(attr, (PmfStat, IntStat)):
+            # Create a new instance of the stat for this object
+            if isinstance(attr, PmfStat):
+                new_stat = PmfStat(attr.name)
+            else:
+                new_stat = IntStat(attr.name)
+            instance_stats[attr_name] = new_stat
+            # Set on instance to shadow class attribute
+            object.__setattr__(obj, attr_name, new_stat)
+
+    # Register under the given path (remove leading /)
+    reg_name = path.lstrip('/')
+    if instance_stats:
+        stats_coll = StatsCollection(reg_name, *instance_stats.values())
+        with _registry_lock:
+            _stats_registry[reg_name] = stats_coll
+
+
 class Metrics(object):
     """
     A collection of timers and counters for various performance metrics.
@@ -34,7 +310,7 @@ class Metrics(object):
 
     request_timer = None
     """
-    A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like
+    A :class:`~cassandra.metrics.PmfStat` timer for requests. 
This is a dict-like
     object with the following keys:
 
       * count - number of requests that have been timed
 
@@ -52,64 +328,64 @@ class Metrics(object):
 
     connection_errors = None
     """
-    A :class:`greplin.scales.IntStat` count of the number of times that a
+    A :class:`~cassandra.metrics.IntStat` count of the number of times that a
     request to a Cassandra node has failed due to a connection problem.
     """
 
     write_timeouts = None
     """
-    A :class:`greplin.scales.IntStat` count of write requests that resulted
+    A :class:`~cassandra.metrics.IntStat` count of write requests that resulted
     in a timeout.
     """
 
     read_timeouts = None
     """
-    A :class:`greplin.scales.IntStat` count of read requests that resulted
+    A :class:`~cassandra.metrics.IntStat` count of read requests that resulted
     in a timeout.
     """
 
     unavailables = None
     """
-    A :class:`greplin.scales.IntStat` count of write or read requests that
+    A :class:`~cassandra.metrics.IntStat` count of write or read requests that
     failed due to an insufficient number of replicas being alive to meet
     the requested :class:`.ConsistencyLevel`.
     """
 
     other_errors = None
     """
-    A :class:`greplin.scales.IntStat` count of all other request failures,
+    A :class:`~cassandra.metrics.IntStat` count of all other request failures,
     including failures caused by invalid requests, bootstrapping nodes,
     overloaded nodes, etc.
     """
 
     retries = None
     """
-    A :class:`greplin.scales.IntStat` count of the number of times a
+    A :class:`~cassandra.metrics.IntStat` count of the number of times a
     request was retried based on the :class:`.RetryPolicy` decision.
     """
 
     ignores = None
     """
-    A :class:`greplin.scales.IntStat` count of the number of times a
+    A :class:`~cassandra.metrics.IntStat` count of the number of times a
     failed request was ignored based on the :class:`.RetryPolicy` decision.
     """
 
     known_hosts = None
     """
-    A :class:`greplin.scales.IntStat` count of the number of nodes in
+    A :class:`~cassandra.metrics.Stat` gauge of the number of nodes in
     the cluster that the driver is aware of, regardless of whether any
     connections are opened to those nodes.
     """
 
     connected_to = None
     """
-    A :class:`greplin.scales.IntStat` count of the number of nodes that
+    A :class:`~cassandra.metrics.Stat` gauge of the number of nodes that
     the driver currently has at least one connection open to.
     """
 
     open_connections = None
     """
-    A :class:`greplin.scales.IntStat` count of the number connections
+    A :class:`~cassandra.metrics.Stat` gauge of the number of connections
     the driver currently has open. 
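+
+    Gauges are evaluated lazily: reading ``.value`` invokes the underlying
+    callable. Illustrative sketch, assuming a connected ``cluster``::
+
+        cluster.metrics.open_connections.value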
""" @@ -120,28 +396,29 @@ def __init__(self, cluster_proxy): self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter)) Metrics._stats_counter += 1 - self.stats = scales.collection(self.stats_name, - scales.PmfStat('request_timer'), - scales.IntStat('connection_errors'), - scales.IntStat('write_timeouts'), - scales.IntStat('read_timeouts'), - scales.IntStat('unavailables'), - scales.IntStat('other_errors'), - scales.IntStat('retries'), - scales.IntStat('ignores'), + self.stats = collection(self.stats_name, + PmfStat('request_timer'), + IntStat('connection_errors'), + IntStat('write_timeouts'), + IntStat('read_timeouts'), + IntStat('unavailables'), + IntStat('other_errors'), + IntStat('retries'), + IntStat('ignores'), # gauges - scales.Stat('known_hosts', + Stat('known_hosts', lambda: len(cluster_proxy.metadata.all_hosts())), - scales.Stat('connected_to', + Stat('connected_to', lambda: len(set(chain.from_iterable(list(s._pools.keys()) for s in cluster_proxy.sessions)))), - scales.Stat('open_connections', + Stat('open_connections', lambda: sum(sum(p.open_count for p in list(s._pools.values())) for s in cluster_proxy.sessions))) # TODO, to be removed in 4.0 # /cassandra contains the metrics of the first cluster registered - if 'cassandra' not in scales._Stats.stats: - scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name] + with _registry_lock: + if 'cassandra' not in _stats_registry: + _stats_registry['cassandra'] = _stats_registry[self.stats_name] self.request_timer = self.stats.request_timer self.connection_errors = self.stats.connection_errors @@ -180,22 +457,23 @@ def get_stats(self): """ Returns the metrics for the registered cluster instance. """ - return scales.getStats()[self.stats_name] + return getStats()[self.stats_name] def set_stats_name(self, stats_name): """ Set the metrics stats name. - The stats_name is a string used to access the metris through scales: scales.getStats()[] + The stats_name is a string used to access the metrics through getStats(): getStats()[] Default is 'cassandra-'. """ if self.stats_name == stats_name: return - if stats_name in scales._Stats.stats: - raise ValueError('"{0}" already exists in stats.'.format(stats_name)) + with _registry_lock: + if stats_name in _stats_registry: + raise ValueError('"{0}" already exists in stats.'.format(stats_name)) - stats = scales._Stats.stats[self.stats_name] - del scales._Stats.stats[self.stats_name] - self.stats_name = stats_name - scales._Stats.stats[self.stats_name] = stats + stats = _stats_registry[self.stats_name] + del _stats_registry[self.stats_name] + self.stats_name = stats_name + _stats_registry[self.stats_name] = stats diff --git a/docs/installation.rst b/docs/installation.rst index 8e4e54e036..4efd87f07a 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -47,15 +47,6 @@ For snappy support:: (If using a Debian Linux derivative such as Ubuntu, it may be easier to just run ``apt-get install python-snappy``.) -(*Optional*) Metrics Support ----------------------------- -The driver has built-in support for capturing :attr:`.Cluster.metrics` about -the queries you run. 
However, the ``scales`` library is required to -support this:: - - pip install scales - - Speeding Up Installation ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/pyproject.toml b/docs/pyproject.toml index f6ee417aee..460e6a5609 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -18,7 +18,6 @@ dependencies = [ "sphinx-scylladb-theme>=1.8.2,<2.0.0", "sphinx-multiversion-scylla>=0.3.2,<1.0.0", "sphinx>=8.2.3,<9.0.0", - "scales>=1.0.9,<2.0.0", "six>=1.9", "tornado>=6.5,<7.0", ] diff --git a/examples/request_init_listener.py b/examples/request_init_listener.py index 2ca6df495a..e23ac80fbf 100644 --- a/examples/request_init_listener.py +++ b/examples/request_init_listener.py @@ -19,7 +19,7 @@ # this is just demonstrating a way to track a few custom attributes. from cassandra.cluster import Cluster -from greplin import scales +from cassandra.metrics import PmfStat, IntStat, init import pprint pp = pprint.PrettyPrinter(indent=2) @@ -32,11 +32,11 @@ class RequestAnalyzer(object): Also computes statistics on encoded request size. """ - requests = scales.PmfStat('request size') - errors = scales.IntStat('errors') + requests = PmfStat('request size') + errors = IntStat('errors') def __init__(self, session): - scales.init(self, '/cassandra') + init(self, '/cassandra') # each instance will be registered with a session, and receive a callback for each request generated session.add_request_init_listener(self.on_request) @@ -91,7 +91,7 @@ def __str__(self): pass print() -print(ra) # note: the counts are updated, but the stats are not because scales only updates every 20s +print(ra) # 3 requests (1 errors) # Request size statistics: # { '75percentile': 74, diff --git a/pyproject.toml b/pyproject.toml index 1c195e1b77..7dd3320fc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ auth-kerberos = [ dev = [ "pytest~=8.0", "PyYAML", - "scales", "pure-sasl", "twisted[tls]", "gevent", diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 48c7b49b95..aa9690623d 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -25,7 +25,7 @@ from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT from tests.integration import get_cluster, get_node, use_singledc, execute_until_pass, TestCluster -from greplin import scales +from cassandra import metrics from tests.integration import BasicSharedKeyspaceUnitTestCaseRF3WM, BasicExistingKeyspaceUnitTestCase, local import pprint as pp @@ -223,7 +223,7 @@ def test_metrics_per_cluster(self): finally: get_node(1).resume() - # Change the scales stats_name of the cluster2 + # Change the stats_name of the cluster2 cluster2.metrics.set_stats_name('cluster2-metrics') stats_cluster1 = self.cluster.metrics.get_stats() @@ -242,7 +242,7 @@ def test_metrics_per_cluster(self): assert 0.0 == stats_cluster2['request_timer']['mean'] # Test access by stats_name - assert 0.0 == scales.getStats()['cluster2-metrics']['request_timer']['mean'] + assert 0.0 == metrics.getStats()['cluster2-metrics']['request_timer']['mean'] cluster2.shutdown() @@ -289,9 +289,9 @@ def test_duplicate_metrics_per_cluster(self): assert cluster2.metrics.get_stats()['request_timer']['count'] == 10 assert cluster3.metrics.get_stats()['request_timer']['count'] == 5 - # Check scales to ensure they are appropriately named - assert "appcluster" in scales._Stats.stats.keys() - assert "devops" in scales._Stats.stats.keys() + # Check registry to ensure they are appropriately named 
+ assert "appcluster" in metrics._stats_registry.keys() + assert "devops" in metrics._stats_registry.keys() cluster2.shutdown() cluster3.shutdown() @@ -303,15 +303,15 @@ class RequestAnalyzer(object): Also computes statistics on encoded request size. """ - requests = scales.PmfStat('request size') - errors = scales.IntStat('errors') - successful = scales.IntStat("success") + requests = metrics.PmfStat('request size') + errors = metrics.IntStat('errors') + successful = metrics.IntStat("success") # Throw exceptions when invoked. throw_on_success = False throw_on_fail = False def __init__(self, session, throw_on_success=False, throw_on_fail=False): - scales.init(self, '/request') + metrics.init(self, '/request') # each instance will be registered with a session, and receive a callback for each request generated session.add_request_init_listener(self.on_request) self.throw_on_fail = throw_on_fail diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py new file mode 100644 index 0000000000..88b1bff4ff --- /dev/null +++ b/tests/unit/test_metrics.py @@ -0,0 +1,304 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the self-contained metrics module. +""" + +import threading +import unittest + +from cassandra.metrics import ( + IntStat, Stat, PmfStat, StatsCollection, + collection, init, getStats, _stats_registry, _registry_lock +) + + +class IntStatTest(unittest.TestCase): + """Tests for IntStat class.""" + + def test_initial_value(self): + stat = IntStat('test_counter') + self.assertEqual(stat.value, 0) + self.assertEqual(int(stat), 0) + + def test_increment(self): + stat = IntStat('test_counter') + stat += 1 + self.assertEqual(stat.value, 1) + stat += 5 + self.assertEqual(stat.value, 6) + + def test_thread_safety(self): + stat = IntStat('test_counter') + num_threads = 10 + increments_per_thread = 1000 + + def increment(): + nonlocal stat + for _ in range(increments_per_thread): + stat += 1 + + threads = [threading.Thread(target=increment) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(stat.value, num_threads * increments_per_thread) + + def test_repr(self): + stat = IntStat('my_counter') + stat += 42 + self.assertEqual(repr(stat), "IntStat(my_counter=42)") + + +class StatTest(unittest.TestCase): + """Tests for Stat (gauge) class.""" + + def test_basic_gauge(self): + counter = [0] + stat = Stat('test_gauge', lambda: counter[0]) + + self.assertEqual(stat.value, 0) + counter[0] = 10 + self.assertEqual(stat.value, 10) + counter[0] = 42 + self.assertEqual(stat.value, 42) + + def test_repr(self): + stat = Stat('my_gauge', lambda: 123) + self.assertEqual(repr(stat), "Stat(my_gauge=123)") + + +class PmfStatTest(unittest.TestCase): + """Tests for PmfStat class.""" + + def test_empty_stats(self): + stat = PmfStat('test_timer') + stats = stat._get_stats() + + self.assertEqual(stats['count'], 0) + self.assertEqual(stats['min'], 0.0) + self.assertEqual(stats['max'], 0.0) + 
self.assertEqual(stats['mean'], 0.0) + self.assertEqual(stats['stddev'], 0.0) + self.assertEqual(stats['median'], 0.0) + + def test_single_value(self): + stat = PmfStat('test_timer') + stat.addValue(10.0) + stats = stat._get_stats() + + self.assertEqual(stats['count'], 1) + self.assertEqual(stats['min'], 10.0) + self.assertEqual(stats['max'], 10.0) + self.assertEqual(stats['mean'], 10.0) + self.assertEqual(stats['stddev'], 0.0) + self.assertEqual(stats['median'], 10.0) + + def test_multiple_values(self): + stat = PmfStat('test_timer') + for v in [1, 2, 3, 4, 5]: + stat.addValue(v) + stats = stat._get_stats() + + self.assertEqual(stats['count'], 5) + self.assertEqual(stats['min'], 1.0) + self.assertEqual(stats['max'], 5.0) + self.assertEqual(stats['mean'], 3.0) + self.assertEqual(stats['median'], 3.0) + + def test_dict_like_access(self): + stat = PmfStat('test_timer') + stat.addValue(5.0) + + self.assertEqual(stat['count'], 1) + self.assertEqual(stat['mean'], 5.0) + self.assertIn('count', stat.keys()) + self.assertIn('mean', stat.keys()) + + def test_percentiles(self): + stat = PmfStat('test_timer') + # Add values 1-100 + for v in range(1, 101): + stat.addValue(v) + stats = stat._get_stats() + + self.assertEqual(stats['count'], 100) + self.assertEqual(stats['min'], 1.0) + self.assertEqual(stats['max'], 100.0) + # Median should be around 50 + self.assertAlmostEqual(stats['median'], 50.5, delta=1) + # 75th percentile should be around 75 + self.assertAlmostEqual(stats['75percentile'], 75.25, delta=1) + # 95th percentile should be around 95 + self.assertAlmostEqual(stats['95percentile'], 95.05, delta=1) + + def test_thread_safety(self): + stat = PmfStat('test_timer') + num_threads = 10 + values_per_thread = 100 + + def add_values(): + for i in range(values_per_thread): + stat.addValue(i) + + threads = [threading.Thread(target=add_values) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + stats = stat._get_stats() + self.assertEqual(stats['count'], num_threads * values_per_thread) + + +class StatsCollectionTest(unittest.TestCase): + """Tests for StatsCollection class.""" + + def test_access_stats_by_attribute(self): + int_stat = IntStat('errors') + pmf_stat = PmfStat('latency') + + coll = StatsCollection('test', int_stat, pmf_stat) + + self.assertIs(coll.errors, int_stat) + self.assertIs(coll.latency, pmf_stat) + + def test_augmented_assignment(self): + int_stat = IntStat('errors') + coll = StatsCollection('test', int_stat) + + coll.errors += 1 + self.assertEqual(int_stat.value, 1) + coll.errors += 5 + self.assertEqual(int_stat.value, 6) + + def test_get_stats_dict(self): + int_stat = IntStat('errors') + int_stat += 3 + pmf_stat = PmfStat('latency') + pmf_stat.addValue(10.0) + gauge = Stat('connections', lambda: 5) + + coll = StatsCollection('test', int_stat, pmf_stat, gauge) + stats_dict = coll._get_stats_dict() + + self.assertEqual(stats_dict['errors'], 3) + self.assertEqual(stats_dict['connections'], 5) + self.assertIsInstance(stats_dict['latency'], dict) + self.assertEqual(stats_dict['latency']['count'], 1) + self.assertEqual(stats_dict['latency']['mean'], 10.0) + + def test_nonexistent_attribute(self): + coll = StatsCollection('test', IntStat('errors')) + with self.assertRaises(AttributeError): + _ = coll.nonexistent + + def test_cannot_set_new_attribute(self): + coll = StatsCollection('test', IntStat('errors')) + with self.assertRaises(AttributeError): + coll.new_attr = 123 + + +class CollectionFunctionTest(unittest.TestCase): + """Tests for the 
collection() function.""" + + def setUp(self): + # Clean up registry before each test + with _registry_lock: + keys_to_remove = [k for k in _stats_registry.keys() + if k.startswith('test_')] + for k in keys_to_remove: + del _stats_registry[k] + + def test_registers_in_global_registry(self): + coll = collection('test_coll', IntStat('counter')) + + with _registry_lock: + self.assertIn('test_coll', _stats_registry) + self.assertIs(_stats_registry['test_coll'], coll) + + def test_get_stats_returns_dict(self): + int_stat = IntStat('counter') + int_stat += 10 + collection('test_stats', int_stat) + + stats = getStats() + self.assertIn('test_stats', stats) + self.assertEqual(stats['test_stats']['counter'], 10) + + +class InitFunctionTest(unittest.TestCase): + """Tests for the init() function.""" + + def setUp(self): + # Clean up registry before each test + with _registry_lock: + keys_to_remove = [k for k in _stats_registry.keys() + if k.startswith('test') or k == 'request'] + for k in keys_to_remove: + del _stats_registry[k] + + def test_creates_instance_stats(self): + class Analyzer: + requests = PmfStat('request size') + errors = IntStat('errors') + + def __init__(self): + init(self, '/test_analyzer') + + analyzer = Analyzer() + + # Instance should have its own stats + self.assertIsInstance(analyzer.requests, PmfStat) + self.assertIsInstance(analyzer.errors, IntStat) + + # Should be different from class-level stats + self.assertIsNot(analyzer.requests, Analyzer.requests) + self.assertIsNot(analyzer.errors, Analyzer.errors) + + def test_instance_stats_are_independent(self): + class Analyzer: + errors = IntStat('errors') + + def __init__(self): + init(self, '/test_analyzer') + + a1 = Analyzer() + a2 = Analyzer() + + a1.errors += 5 + a2.errors += 10 + + self.assertEqual(a1.errors.value, 5) + self.assertEqual(a2.errors.value, 10) + + def test_strips_leading_slash(self): + class Analyzer: + errors = IntStat('errors') + + def __init__(self): + init(self, '/test_path') + + Analyzer() + + with _registry_lock: + self.assertIn('test_path', _stats_registry) + self.assertNotIn('/test_path', _stats_registry) + + +if __name__ == '__main__': + unittest.main() From 70d218bd1c0b9ccd1f42f41299ce826b1b1f22f4 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sun, 1 Feb 2026 02:25:49 -0400 Subject: [PATCH 2/2] Fix IntStat comparison operators and metrics cleanup on shutdown - Add comparison operators (__eq__, __ne__, __lt__, __le__, __gt__, __ge__) to IntStat class for direct comparison with integers - Add __hash__ method to IntStat for use in sets/dicts - Add Metrics.shutdown() method to remove stats from global registry - Call metrics.shutdown() from Cluster.shutdown() to prevent stale weakref errors when accessing getStats() after cluster shutdown - Fix test using 'is' instead of '==' for IntStat comparison - Add unit tests for new comparison operators and shutdown functionality --- cassandra/cluster.py | 3 + cassandra/metrics.py | 43 ++++++++++++ tests/integration/standard/test_metrics.py | 4 +- tests/unit/test_metrics.py | 82 ++++++++++++++++++++++ 4 files changed, 130 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 099043eae0..622b706330 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1795,6 +1795,9 @@ def shutdown(self): self.executor.shutdown() + if self.metrics_enabled and self.metrics: + self.metrics.shutdown() + _discard_cluster_shutdown(self) def __enter__(self): diff --git a/cassandra/metrics.py b/cassandra/metrics.py index 
6a1af793ec..7ff44107af 100644 --- a/cassandra/metrics.py +++ b/cassandra/metrics.py @@ -61,6 +61,37 @@ def __iadd__(self, other): def __int__(self): return self._value + def __eq__(self, other): + if isinstance(other, IntStat): + return self._value == other._value + return self._value == other + + def __ne__(self, other): + return not self.__eq__(other) + + def __lt__(self, other): + if isinstance(other, IntStat): + return self._value < other._value + return self._value < other + + def __le__(self, other): + if isinstance(other, IntStat): + return self._value <= other._value + return self._value <= other + + def __gt__(self, other): + if isinstance(other, IntStat): + return self._value > other._value + return self._value > other + + def __ge__(self, other): + if isinstance(other, IntStat): + return self._value >= other._value + return self._value >= other + + def __hash__(self): + return hash(self._value) + def __repr__(self): return f"IntStat({self.name}={self._value})" @@ -477,3 +508,15 @@ def set_stats_name(self, stats_name): del _stats_registry[self.stats_name] self.stats_name = stats_name _stats_registry[self.stats_name] = stats + + def shutdown(self): + """ + Remove this metrics instance from the global registry. + Called when the cluster is shutdown to prevent stale references. + """ + with _registry_lock: + if self.stats_name in _stats_registry: + del _stats_registry[self.stats_name] + # Also clean up the legacy 'cassandra' entry if it points to our stats + if _stats_registry.get('cassandra') is self.stats: + del _stats_registry['cassandra'] diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index aa9690623d..7b502d91c3 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -355,10 +355,10 @@ def setUpClass(cls): def wait_for_count(self, ra, expected_count, error=False): for _ in range(10): if not error: - if ra.successful is expected_count: + if ra.successful == expected_count: return True else: - if ra.errors is expected_count: + if ra.errors == expected_count: return True time.sleep(.01) return False diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 88b1bff4ff..3a8d1d2432 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -63,6 +63,35 @@ def test_repr(self): stat += 42 self.assertEqual(repr(stat), "IntStat(my_counter=42)") + def test_equality(self): + stat = IntStat('test') + stat += 5 + self.assertEqual(stat, 5) + self.assertEqual(5, stat) + self.assertNotEqual(stat, 3) + self.assertNotEqual(stat, 10) + + def test_comparison_operators(self): + stat = IntStat('test') + stat += 5 + self.assertTrue(stat > 0) + self.assertTrue(stat >= 5) + self.assertTrue(stat < 10) + self.assertTrue(stat <= 5) + self.assertFalse(stat > 5) + self.assertFalse(stat < 5) + + def test_comparison_with_intstat(self): + stat1 = IntStat('test1') + stat2 = IntStat('test2') + stat1 += 5 + stat2 += 5 + self.assertEqual(stat1, stat2) + stat2 += 1 + self.assertNotEqual(stat1, stat2) + self.assertTrue(stat1 < stat2) + self.assertTrue(stat2 > stat1) + class StatTest(unittest.TestCase): """Tests for Stat (gauge) class.""" @@ -300,5 +329,58 @@ def __init__(self): self.assertNotIn('/test_path', _stats_registry) +class MetricsShutdownTest(unittest.TestCase): + """Tests for Metrics shutdown functionality.""" + + def setUp(self): + # Clean up registry before each test + with _registry_lock: + keys_to_remove = [k for k in _stats_registry.keys() + if 
k.startswith('cassandra')]
+            for k in keys_to_remove:
+                del _stats_registry[k]
+
+    def test_shutdown_removes_from_registry(self):
+        import weakref
+        from cassandra.metrics import Metrics
+
+        class MockCluster:
+            def __init__(self):
+                # the lambda is bound as a method when looked up on the
+                # instance (metadata.all_hosts()), so it must accept `self`
+                self.metadata = type('obj', (object,), {'all_hosts': lambda self: []})()
+                self.sessions = []
+
+        cluster = MockCluster()
+        proxy = weakref.proxy(cluster)
+
+        metrics = Metrics(proxy)
+        stats_name = metrics.stats_name
+
+        with _registry_lock:
+            self.assertIn(stats_name, _stats_registry)
+
+        metrics.shutdown()
+
+        with _registry_lock:
+            self.assertNotIn(stats_name, _stats_registry)
+
+    def test_shutdown_is_idempotent(self):
+        import weakref
+        from cassandra.metrics import Metrics
+
+        class MockCluster:
+            def __init__(self):
+                # bound as a method, so the lambda takes `self`
+                self.metadata = type('obj', (object,), {'all_hosts': lambda self: []})()
+                self.sessions = []
+
+        cluster = MockCluster()
+        proxy = weakref.proxy(cluster)
+
+        metrics = Metrics(proxy)
+
+        # Should not raise even if called multiple times
+        metrics.shutdown()
+        metrics.shutdown()
+
+
 if __name__ == '__main__':
     unittest.main()