From 9ea1ed1782e2054a1adf73388a4c7981860135e0 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 13 Mar 2026 11:31:48 +0200 Subject: [PATCH] (improvement) cache namedtuple class in named_tuple_factory to avoid repeated exec() calls Cache the Row namedtuple class keyed on tuple(colnames) so Python's namedtuple() (which internally calls exec()) is only invoked once per unique column schema. For prepared statements the column names never change, eliminating redundant class creation on every result set. ## Motivation named_tuple_factory is the default row_factory in the driver. Every call to namedtuple('Row', columns) internally calls exec() to generate a new class -- this is surprisingly expensive. For prepared statements executing the same query repeatedly, the column names never change, yet we pay the namedtuple() + exec() cost on every result set. ## Benchmark results Benchmarks compare the original code (Before) against the new cached implementation (After). 10 columns, 1 row (isolates class creation overhead): | Variant | Min | Mean | Median | Ops/sec | |---|---|---|---|---| | Before (original) | 43,490 ns | 59,976 ns | 47,653 ns | 16.7 Kops/s | | After (with cache) | 235 ns | 452 ns | 353 ns | 2,210 Kops/s | 5 columns, 100 rows: | Variant | Min | Mean | Median | Ops/sec | |---|---|---|---|---| | Before (original) | 57.4 us | 91.2 us | 65.8 us | 10,969/s | | After (with cache) | 19.3 us | 25.3 us | 24.0 us | 39,594/s | 10 columns, 100 rows: | Variant | Min | Mean | Median | Ops/sec | |---|---|---|---|---| | Before (original) | 56.7 us | 101.9 us | 75.6 us | 9,813/s | | After (with cache) | 18.1 us | 21.4 us | 20.4 us | 46,825/s | ## Design notes - Cache is a plain dict keyed on tuple(colnames) (raw column names before cleaning) - Error handling paths (SyntaxError, Exception) preserved unchanged - Cache is naturally bounded by the number of distinct queries ## Tests All existing unit tests pass (46 passed). --- .../test_named_tuple_factory_benchmark.py | 206 +++++++ cassandra/query.py | 501 +++++++++++++----- 2 files changed, 571 insertions(+), 136 deletions(-) create mode 100644 benchmarks/test_named_tuple_factory_benchmark.py diff --git a/benchmarks/test_named_tuple_factory_benchmark.py b/benchmarks/test_named_tuple_factory_benchmark.py new file mode 100644 index 0000000000..ed4b7be8ca --- /dev/null +++ b/benchmarks/test_named_tuple_factory_benchmark.py @@ -0,0 +1,206 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmarks for named_tuple_factory with and without namedtuple class caching. + +Run with: pytest benchmarks/test_named_tuple_factory_benchmark.py -v +""" + +import re +import warnings +from collections import namedtuple + +import pytest + +from cassandra.query import named_tuple_factory, _named_tuple_cache +from cassandra.util import _sanitize_identifiers + + +# --------------------------------------------------------------------------- +# Reference: original uncached implementation (copied from master) +# --------------------------------------------------------------------------- + +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") + +_clean_name_cache_old = {} + + +def _clean_column_name_old(name): + try: + return _clean_name_cache_old[name] + except KeyError: + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) + _clean_name_cache_old[name] = clean + return clean + + +def named_tuple_factory_uncached(colnames, rows): + """Original implementation without caching (for benchmark comparison).""" + clean_column_names = map(_clean_column_name_old, colnames) + try: + Row = namedtuple("Row", clean_column_names) + except SyntaxError: + raise + except Exception: + clean_column_names = list(map(_clean_column_name_old, colnames)) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) + return [Row(*row) for row in rows] + + +# --------------------------------------------------------------------------- +# Test data generators +# --------------------------------------------------------------------------- + + +def make_colnames(n): + return tuple(f"col_{i}" for i in range(n)) + + +def make_rows(ncols, nrows): + return [tuple(range(ncols)) for _ in range(nrows)] + + +# --------------------------------------------------------------------------- +# Correctness tests +# --------------------------------------------------------------------------- + + +class TestNamedTupleFactoryCorrectness: + """Verify the cached implementation matches the uncached one.""" + + @pytest.mark.parametrize("ncols", [1, 5, 10, 20]) + @pytest.mark.parametrize("nrows", [1, 10, 100]) + def test_results_match(self, ncols, nrows): + colnames = make_colnames(ncols) + rows = make_rows(ncols, nrows) + _named_tuple_cache.clear() + cached_result = named_tuple_factory(colnames, rows) + uncached_result = named_tuple_factory_uncached(colnames, rows) + assert len(cached_result) == len(uncached_result) + for cr, ur in zip(cached_result, uncached_result): + assert tuple(cr) == tuple(ur) + assert cr._fields == ur._fields + + def test_cache_hit_returns_same_class(self): + colnames = ("name", "age", "email") + rows1 = [("Alice", 30, "a@b.com")] + rows2 = [("Bob", 25, "b@c.com")] + _named_tuple_cache.clear() + result1 = named_tuple_factory(colnames, rows1) + result2 = named_tuple_factory(colnames, rows2) + # Same Row class should be reused + assert type(result1[0]) is type(result2[0]) + + def test_different_schemas_get_different_classes(self): + _named_tuple_cache.clear() + result1 = named_tuple_factory(("a", "b"), [(1, 2)]) + result2 = named_tuple_factory(("x", "y"), [(3, 4)]) + assert type(result1[0]) is not type(result2[0]) + assert result1[0]._fields == ("a", "b") + assert result2[0]._fields == ("x", "y") + + +# --------------------------------------------------------------------------- +# Benchmarks +# --------------------------------------------------------------------------- + + +class TestNamedTupleFactoryBenchmark: + """Benchmark cached vs uncached named_tuple_factory.""" + + # --- 5 columns, 100 rows --- + + @pytest.mark.benchmark(group="ntf_5cols_100rows") + def test_uncached_5cols_100rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 100) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_5cols_100rows") + def test_cached_5cols_100rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 100) + _named_tuple_cache.clear() + # Warm the cache with one call + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 10 columns, 100 rows --- + + @pytest.mark.benchmark(group="ntf_10cols_100rows") + def test_uncached_10cols_100rows(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 100) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_10cols_100rows") + def test_cached_10cols_100rows(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 100) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 20 columns, 100 rows --- + + @pytest.mark.benchmark(group="ntf_20cols_100rows") + def test_uncached_20cols_100rows(self, benchmark): + colnames = make_colnames(20) + rows = make_rows(20, 100) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_20cols_100rows") + def test_cached_20cols_100rows(self, benchmark): + colnames = make_colnames(20) + rows = make_rows(20, 100) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 5 columns, 1000 rows --- + + @pytest.mark.benchmark(group="ntf_5cols_1000rows") + def test_uncached_5cols_1000rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 1000) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_5cols_1000rows") + def test_cached_5cols_1000rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 1000) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 10 columns, 1 row (measures class creation overhead most clearly) --- + + @pytest.mark.benchmark(group="ntf_10cols_1row") + def test_uncached_10cols_1row(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 1) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_10cols_1row") + def test_cached_10cols_1row(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 1) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..1add236d56 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -34,6 +34,7 @@ from cassandra.util import OrderedDict, _sanitize_identifiers import logging + log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE @@ -49,9 +50,9 @@ Only valid when using native protocol v4+ """ -NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') -START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') -END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") _clean_name_cache = {} @@ -60,7 +61,9 @@ def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: - clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) _clean_name_cache[name] = clean return clean @@ -83,6 +86,7 @@ def tuple_factory(colnames, rows): """ return rows + class PseudoNamedTupleRow(object): """ Helper class for pseudo_named_tuple_factory. These objects provide an @@ -90,6 +94,7 @@ class PseudoNamedTupleRow(object): but otherwise do not attempt to implement the full namedtuple or iterable interface. """ + def __init__(self, ordered_dict): self._dict = ordered_dict self._tuple = tuple(ordered_dict.values()) @@ -104,8 +109,7 @@ def __iter__(self): return iter(self._tuple) def __repr__(self): - return '{t}({od})'.format(t=self.__class__.__name__, - od=self._dict) + return "{t}({od})".format(t=self.__class__.__name__, od=self._dict) def pseudo_namedtuple_factory(colnames, rows): @@ -113,8 +117,13 @@ def pseudo_namedtuple_factory(colnames, rows): Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback factory for cases where :meth:`.named_tuple_factory` fails to create rows. """ - return [PseudoNamedTupleRow(od) - for od in ordered_dict_factory(colnames, rows)] + return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)] + + +# Cache namedtuple Row classes to avoid repeated exec() calls in namedtuple() +# for the same column schema. Naturally bounded by the number of distinct +# column-name tuples, which equals the number of distinct queries. +_named_tuple_cache = {} def named_tuple_factory(colnames, rows): @@ -146,32 +155,41 @@ def named_tuple_factory(colnames, rows): .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ - clean_column_names = map(_clean_column_name, colnames) + key = tuple(colnames) try: - Row = namedtuple('Row', clean_column_names) - except SyntaxError: - warnings.warn( - "Failed creating namedtuple for a result because there were too " - "many columns. This is due to a Python limitation that affects " - "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " - "created with {substitute_factory_name}, which lacks some namedtuple " - "features and is slower. To avoid slower performance accessing " - "values on row objects, Upgrade to Python 3.7, or use a different " - "row factory. (column names: {colnames})".format( - substitute_factory_name=pseudo_namedtuple_factory.__name__, - colnames=colnames + Row = _named_tuple_cache[key] + except KeyError: + clean_column_names = map(_clean_column_name, colnames) + try: + Row = namedtuple("Row", clean_column_names) + except SyntaxError: + warnings.warn( + "Failed creating namedtuple for a result because there were too " + "many columns. This is due to a Python limitation that affects " + "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " + "created with {substitute_factory_name}, which lacks some namedtuple " + "features and is slower. To avoid slower performance accessing " + "values on row objects, Upgrade to Python 3.7, or use a different " + "row factory. (column names: {colnames})".format( + substitute_factory_name=pseudo_namedtuple_factory.__name__, + colnames=colnames, + ) ) - ) - return pseudo_namedtuple_factory(colnames, rows) - except Exception: - clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt - log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " - "(see Python 'namedtuple' documentation for details on name rules). " - "Results will be returned with positional names. " - "Avoid this by choosing different names, using SELECT \"\" AS aliases, " - "or specifying a different row_factory on your Session" % - (colnames, clean_column_names)) - Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) + return pseudo_namedtuple_factory(colnames, rows) + except Exception: + clean_column_names = list( + map(_clean_column_name, colnames) + ) # create list because py3 map object will be consumed by first attempt + log.warning( + "Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. " + 'Avoid this by choosing different names, using SELECT "" AS aliases, ' + "or specifying a different row_factory on your Session" + % (colnames, clean_column_names) + ) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) + _named_tuple_cache[key] = Row return [Row(*row) for row in rows] @@ -276,11 +294,24 @@ class Statement(object): _serial_consistency_level = None _routing_key = None - def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, - is_idempotent=False, table=None): - if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors - raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + def __init__( + self, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + table=None, + ): + if retry_policy and not hasattr( + retry_policy, "on_read_timeout" + ): # just checking one method to detect positional parameter errors + raise ValueError( + "retry_policy should implement cassandra.policies.RetryPolicy" + ) if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: @@ -329,17 +360,20 @@ def _del_routing_key(self): If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. - """) + """, + ) def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - if (serial_consistency_level is not None and - not ConsistencyLevel.is_serial(serial_consistency_level)): + if serial_consistency_level is not None and not ConsistencyLevel.is_serial( + serial_consistency_level + ): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL") + "or ConsistencyLevel.LOCAL_SERIAL" + ) self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): @@ -384,7 +418,8 @@ def is_lwt(self): conditional statements. .. versionadded:: 2.0.0 - """) + """, + ) class SimpleStatement(Statement): @@ -392,9 +427,18 @@ class SimpleStatement(Statement): A simple, un-prepared query. """ - def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None, is_idempotent=False): + def __init__( + self, + query_string, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + ): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the @@ -402,8 +446,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + is_idempotent, + ) self._query_string = query_string @property @@ -411,9 +464,14 @@ def query_string(self): return self._query_string def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -442,7 +500,7 @@ class PreparedStatement(object): A note about * in prepared statements """ - column_metadata = None #TODO: make this bind_metadata in next major + column_metadata = None # TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None @@ -459,9 +517,19 @@ class PreparedStatement(object): serial_consistency_level = None # TODO never used? _is_lwt = False - def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata, result_metadata_id, - is_lwt=False, column_encryption_policy=None): + def __init__( + self, + column_metadata, + query_id, + routing_key_indexes, + query, + keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt=False, + column_encryption_policy=None, + ): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -475,13 +543,33 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self._is_lwt = is_lwt @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy=None): + def from_message( + cls, + query_id, + column_metadata, + pk_indexes, + cluster_metadata, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy=None, + ): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + return PreparedStatement( + column_metadata, + query_id, + None, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) if pk_indexes: routing_key_indexes = pk_indexes @@ -496,18 +584,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement - statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) + statement_indexes = dict( + (c.name, i) for i, c in enumerate(column_metadata) + ) # a list of which indexes in the statement correspond to partition key items try: - routing_key_indexes = [statement_indexes[c.name] - for c in partition_key_columns] - except KeyError: # we're missing a partition key component in the prepared - pass # statement; just leave routing_key_indexes as None - - return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + routing_key_indexes = [ + statement_indexes[c.name] for c in partition_key_columns + ] + except ( + KeyError + ): # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None + + return PreparedStatement( + column_metadata, + query_id, + routing_key_indexes, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) def bind(self, values): """ @@ -519,16 +621,23 @@ def bind(self, values): def is_routing_key_index(self, i): if self._routing_key_index_set is None: - self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() + self._routing_key_index_set = ( + set(self.routing_key_indexes) if self.routing_key_indexes else set() + ) return i in self._routing_key_index_set def is_lwt(self): return self._is_lwt def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -548,9 +657,17 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + def __init__( + self, + prepared_statement, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + ): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. @@ -571,9 +688,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None self.keyspace = meta[0].keyspace_name self.table = meta[0].table_name - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, - prepared_statement.is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + prepared_statement.is_idempotent, + ) def bind(self, values): """ @@ -615,24 +740,29 @@ def bind(self, values): values.append(UNSET_VALUE) else: raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + "Column name `%s` not found in bound dict." % (col.name) + ) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( - "Too many arguments provided to bind() (got %d, expected %d)" % - (len(values), len(col_meta))) + "Too many arguments provided to bind() (got %d, expected %d)" + % (len(values), len(col_meta)) + ) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. - if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ - value_len < len(self.prepared_statement.routing_key_indexes): + if ( + proto_version < 4 + and self.prepared_statement.routing_key_indexes + and value_len < len(self.prepared_statement.routing_key_indexes) + ): raise ValueError( - "Too few arguments provided to bind() (got %d, required %d for routing key)" % - (value_len, len(self.prepared_statement.routing_key_indexes))) + "Too few arguments provided to bind() (got %d, required %d for routing key)" + % (value_len, len(self.prepared_statement.routing_key_indexes)) + ) self.raw_values = values self.values = [] @@ -643,20 +773,30 @@ def bind(self, values): if proto_version >= 4: self._append_unset_value() else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + col_desc = ColDesc( + col_spec.keyspace_name, col_spec.table_name, col_spec.name + ) uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type + col_type = ( + ce_policy.column_type(col_desc) if uses_ce else col_spec.type + ) col_bytes = col_type.serialize(value, proto_version) if uses_ce: col_bytes = ce_policy.encrypt(col_desc, col_bytes) self.values.append(col_bytes) except (TypeError, struct.error) as exc: actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + message = ( + 'Received an argument of invalid type for column "%s". ' + "Expected: %s, Got: %s; (%s)" + % (col_spec.name, col_spec.type, actual_type, exc) + ) raise TypeError(message) if proto_version >= 4: @@ -671,7 +811,10 @@ def _append_unset_value(self): next_index = len(self.values) if self.prepared_statement.is_routing_key_index(next_index): col_meta = self.prepared_statement.column_metadata[next_index] - raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) + raise ValueError( + "Cannot bind UNSET_VALUE as a part of the routing key '%s'" + % col_meta.name + ) self.values.append(UNSET_VALUE) @property @@ -686,7 +829,9 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) + self._routing_key = b"".join( + self._key_parts_packed(self.values[i] for i in routing_indexes) + ) return self._routing_key @@ -694,9 +839,15 @@ def is_lwt(self): return self.prepared_statement.is_lwt() def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.prepared_statement.query_string, self.raw_values, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.prepared_statement.query_string, + self.raw_values, + consistency, + ) + __repr__ = __str__ @@ -731,7 +882,7 @@ def __str__(self): return self.name def __repr__(self): - return "BatchType.%s" % (self.name, ) + return "BatchType.%s" % (self.name,) BatchType.LOGGED = BatchType("LOGGED", 0) @@ -763,9 +914,15 @@ class BatchStatement(Statement): _session = None _is_lwt = False - def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, - consistency_level=None, serial_consistency_level=None, - session=None, custom_payload=None): + def __init__( + self, + batch_type=BatchType.LOGGED, + retry_policy=None, + consistency_level=None, + serial_consistency_level=None, + session=None, + custom_payload=None, + ): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. @@ -813,8 +970,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, self.batch_type = batch_type self._statements_and_parameters = [] self._session = session - Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, - serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + Statement.__init__( + self, + retry_policy=retry_policy, + consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level, + custom_payload=custom_payload, + ) def clear(self): """ @@ -853,11 +1015,14 @@ def add(self, statement, parameters=None): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " - "to BatchStatement.add()") + "to BatchStatement.add()" + ) self._update_state(statement) if statement.is_lwt(): self._is_lwt = True - self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) + self._add_statement_and_params( + True, statement.prepared_statement.query_id, statement.values + ) else: # it must be a SimpleStatement query_string = statement.query_string @@ -881,7 +1046,9 @@ def add_all(self, statements, parameters): def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: - raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + raise ValueError( + "Batch statement cannot contain more than %d statements." % 0xFFFF + ) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): @@ -907,9 +1074,15 @@ def __len__(self): return len(self._statements_and_parameters) def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.batch_type, len(self), consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return "" % ( + self.batch_type, + len(self), + consistency, + ) + __repr__ = __str__ @@ -931,7 +1104,9 @@ def __str__(self): def bind_params(query, params, encoder): if isinstance(params, dict): - return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) + return query % dict( + (k, encoder.cql_encode_all_types(v)) for k, v in params.items() + ) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) @@ -940,6 +1115,7 @@ class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ + pass @@ -1000,7 +1176,9 @@ class QueryTrace(object): _session = None - _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" + _SELECT_SESSIONS_FORMAT = ( + "SELECT * FROM system_traces.sessions WHERE session_id = %s" + ) _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 @@ -1029,18 +1207,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( - "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." + % (max_wait,) + ) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) - metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout + metadata_request_timeout = ( + self._session.cluster.control_connection + and self._session.cluster.control_connection._metadata_request_timeout + ) session_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_SESSIONS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), + (self.trace_id,), + time_spent, + max_wait, + ) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries session_row = session_results.one() if session_results else None - is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None + is_complete = ( + session_row is not None + and session_row.duration is not None + and session_row.started_at is not None + ) if not session_results or (wait_for_complete and not is_complete): - time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + time.sleep(self._BASE_RETRY_SLEEP * (2**attempt)) attempt += 1 continue if is_complete: @@ -1049,29 +1245,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) self.request_type = session_row.request - self.duration = timedelta(microseconds=session_row.duration) if is_complete else None + self.duration = ( + timedelta(microseconds=session_row.duration) if is_complete else None + ) self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 - self.client = getattr(session_row, 'client', None) + self.client = getattr(session_row, "client", None) - log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + log.debug( + "Attempting to fetch trace events for trace ID: %s", self.trace_id + ) time_spent = time.time() - start event_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout), - consistency_level=query_cl), + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_EVENTS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), (self.trace_id,), time_spent, - max_wait) + max_wait, + ) log.debug("Fetched trace events for trace ID: %s", self.trace_id) - self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) - for r in event_results) + self.events = tuple( + TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) + for r in event_results + ) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_response_future( + query, parameters, trace=False, custom_payload=None, timeout=timeout + ) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() @@ -1079,12 +1288,22 @@ def _execute(self, query, parameters, time_spent, max_wait): try: return future.result() except OperationTimedOut: - raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + raise TraceUnavailable( + "Trace information was not available within %f seconds" % (max_wait,) + ) def __str__(self): - return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ - % (self.request_type, self.trace_id, self.coordinator, self.started_at, - self.duration, self.parameters) + return ( + "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" + % ( + self.request_type, + self.trace_id, + self.coordinator, + self.started_at, + self.duration, + self.parameters, + ) + ) class TraceEvent(object): @@ -1121,7 +1340,9 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) + self.datetime = datetime.fromtimestamp( + unix_time_from_uuid1(timeuuid), tz=timezone.utc + ) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) @@ -1130,7 +1351,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.thread_name = thread_name def __str__(self): - return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + return "%s on %s[%s] at %s" % ( + self.description, + self.source, + self.thread_name, + self.datetime, + ) # TODO remove next major since we can target using the `host` attribute of session.execute @@ -1139,9 +1365,12 @@ class HostTargetingStatement(object): Wraps any query statement and attaches a target host, making it usable in a targeted LBP without modifying the user's statement. """ + def __init__(self, inner_statement, target_host): - self.__class__ = type(inner_statement.__class__.__name__, - (self.__class__, inner_statement.__class__), - {}) - self.__dict__ = inner_statement.__dict__ - self.target_host = target_host + self.__class__ = type( + inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}, + ) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host