diff --git a/benchmarks/test_named_tuple_factory_benchmark.py b/benchmarks/test_named_tuple_factory_benchmark.py new file mode 100644 index 0000000000..ed4b7be8ca --- /dev/null +++ b/benchmarks/test_named_tuple_factory_benchmark.py @@ -0,0 +1,206 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmarks for named_tuple_factory with and without namedtuple class caching. + +Run with: pytest benchmarks/test_named_tuple_factory_benchmark.py -v +""" + +import re +import warnings +from collections import namedtuple + +import pytest + +from cassandra.query import named_tuple_factory, _named_tuple_cache +from cassandra.util import _sanitize_identifiers + + +# --------------------------------------------------------------------------- +# Reference: original uncached implementation (copied from master) +# --------------------------------------------------------------------------- + +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") + +_clean_name_cache_old = {} + + +def _clean_column_name_old(name): + try: + return _clean_name_cache_old[name] + except KeyError: + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) + _clean_name_cache_old[name] = clean + return clean + + +def named_tuple_factory_uncached(colnames, rows): + """Original implementation without caching (for benchmark comparison).""" + clean_column_names = map(_clean_column_name_old, colnames) + try: + Row = namedtuple("Row", clean_column_names) + except SyntaxError: + raise + except Exception: + clean_column_names = list(map(_clean_column_name_old, colnames)) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) + return [Row(*row) for row in rows] + + +# --------------------------------------------------------------------------- +# Test data generators +# --------------------------------------------------------------------------- + + +def make_colnames(n): + return tuple(f"col_{i}" for i in range(n)) + + +def make_rows(ncols, nrows): + return [tuple(range(ncols)) for _ in range(nrows)] + + +# --------------------------------------------------------------------------- +# Correctness tests +# --------------------------------------------------------------------------- + + +class TestNamedTupleFactoryCorrectness: + """Verify the cached implementation matches the uncached one.""" + + @pytest.mark.parametrize("ncols", [1, 5, 10, 20]) + @pytest.mark.parametrize("nrows", [1, 10, 100]) + def test_results_match(self, ncols, nrows): + colnames = make_colnames(ncols) + rows = make_rows(ncols, nrows) + _named_tuple_cache.clear() + cached_result = named_tuple_factory(colnames, rows) + uncached_result = named_tuple_factory_uncached(colnames, rows) + assert len(cached_result) == len(uncached_result) + for cr, ur in zip(cached_result, uncached_result): + assert tuple(cr) == tuple(ur) + assert cr._fields == ur._fields + + def test_cache_hit_returns_same_class(self): + colnames = ("name", "age", "email") + rows1 = [("Alice", 30, "a@b.com")] + rows2 = [("Bob", 25, "b@c.com")] + _named_tuple_cache.clear() + result1 = named_tuple_factory(colnames, rows1) + result2 = named_tuple_factory(colnames, rows2) + # Same Row class should be reused + assert type(result1[0]) is type(result2[0]) + + def test_different_schemas_get_different_classes(self): + _named_tuple_cache.clear() + result1 = named_tuple_factory(("a", "b"), [(1, 2)]) + result2 = named_tuple_factory(("x", "y"), [(3, 4)]) + assert type(result1[0]) is not type(result2[0]) + assert result1[0]._fields == ("a", "b") + assert result2[0]._fields == ("x", "y") + + +# --------------------------------------------------------------------------- +# Benchmarks +# --------------------------------------------------------------------------- + + +class TestNamedTupleFactoryBenchmark: + """Benchmark cached vs uncached named_tuple_factory.""" + + # --- 5 columns, 100 rows --- + + @pytest.mark.benchmark(group="ntf_5cols_100rows") + def test_uncached_5cols_100rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 100) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_5cols_100rows") + def test_cached_5cols_100rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 100) + _named_tuple_cache.clear() + # Warm the cache with one call + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 10 columns, 100 rows --- + + @pytest.mark.benchmark(group="ntf_10cols_100rows") + def test_uncached_10cols_100rows(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 100) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_10cols_100rows") + def test_cached_10cols_100rows(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 100) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 20 columns, 100 rows --- + + @pytest.mark.benchmark(group="ntf_20cols_100rows") + def test_uncached_20cols_100rows(self, benchmark): + colnames = make_colnames(20) + rows = make_rows(20, 100) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_20cols_100rows") + def test_cached_20cols_100rows(self, benchmark): + colnames = make_colnames(20) + rows = make_rows(20, 100) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 5 columns, 1000 rows --- + + @pytest.mark.benchmark(group="ntf_5cols_1000rows") + def test_uncached_5cols_1000rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 1000) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_5cols_1000rows") + def test_cached_5cols_1000rows(self, benchmark): + colnames = make_colnames(5) + rows = make_rows(5, 1000) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) + + # --- 10 columns, 1 row (measures class creation overhead most clearly) --- + + @pytest.mark.benchmark(group="ntf_10cols_1row") + def test_uncached_10cols_1row(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 1) + benchmark(named_tuple_factory_uncached, colnames, rows) + + @pytest.mark.benchmark(group="ntf_10cols_1row") + def test_cached_10cols_1row(self, benchmark): + colnames = make_colnames(10) + rows = make_rows(10, 1) + _named_tuple_cache.clear() + named_tuple_factory(colnames, rows) + benchmark(named_tuple_factory, colnames, rows) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..1add236d56 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -34,6 +34,7 @@ from cassandra.util import OrderedDict, _sanitize_identifiers import logging + log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE @@ -49,9 +50,9 @@ Only valid when using native protocol v4+ """ -NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') -START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') -END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") _clean_name_cache = {} @@ -60,7 +61,9 @@ def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: - clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) _clean_name_cache[name] = clean return clean @@ -83,6 +86,7 @@ def tuple_factory(colnames, rows): """ return rows + class PseudoNamedTupleRow(object): """ Helper class for pseudo_named_tuple_factory. These objects provide an @@ -90,6 +94,7 @@ class PseudoNamedTupleRow(object): but otherwise do not attempt to implement the full namedtuple or iterable interface. """ + def __init__(self, ordered_dict): self._dict = ordered_dict self._tuple = tuple(ordered_dict.values()) @@ -104,8 +109,7 @@ def __iter__(self): return iter(self._tuple) def __repr__(self): - return '{t}({od})'.format(t=self.__class__.__name__, - od=self._dict) + return "{t}({od})".format(t=self.__class__.__name__, od=self._dict) def pseudo_namedtuple_factory(colnames, rows): @@ -113,8 +117,13 @@ def pseudo_namedtuple_factory(colnames, rows): Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback factory for cases where :meth:`.named_tuple_factory` fails to create rows. """ - return [PseudoNamedTupleRow(od) - for od in ordered_dict_factory(colnames, rows)] + return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)] + + +# Cache namedtuple Row classes to avoid repeated exec() calls in namedtuple() +# for the same column schema. Naturally bounded by the number of distinct +# column-name tuples, which equals the number of distinct queries. +_named_tuple_cache = {} def named_tuple_factory(colnames, rows): @@ -146,32 +155,41 @@ def named_tuple_factory(colnames, rows): .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ - clean_column_names = map(_clean_column_name, colnames) + key = tuple(colnames) try: - Row = namedtuple('Row', clean_column_names) - except SyntaxError: - warnings.warn( - "Failed creating namedtuple for a result because there were too " - "many columns. This is due to a Python limitation that affects " - "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " - "created with {substitute_factory_name}, which lacks some namedtuple " - "features and is slower. To avoid slower performance accessing " - "values on row objects, Upgrade to Python 3.7, or use a different " - "row factory. (column names: {colnames})".format( - substitute_factory_name=pseudo_namedtuple_factory.__name__, - colnames=colnames + Row = _named_tuple_cache[key] + except KeyError: + clean_column_names = map(_clean_column_name, colnames) + try: + Row = namedtuple("Row", clean_column_names) + except SyntaxError: + warnings.warn( + "Failed creating namedtuple for a result because there were too " + "many columns. This is due to a Python limitation that affects " + "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " + "created with {substitute_factory_name}, which lacks some namedtuple " + "features and is slower. To avoid slower performance accessing " + "values on row objects, Upgrade to Python 3.7, or use a different " + "row factory. (column names: {colnames})".format( + substitute_factory_name=pseudo_namedtuple_factory.__name__, + colnames=colnames, + ) ) - ) - return pseudo_namedtuple_factory(colnames, rows) - except Exception: - clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt - log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " - "(see Python 'namedtuple' documentation for details on name rules). " - "Results will be returned with positional names. " - "Avoid this by choosing different names, using SELECT \"\" AS aliases, " - "or specifying a different row_factory on your Session" % - (colnames, clean_column_names)) - Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) + return pseudo_namedtuple_factory(colnames, rows) + except Exception: + clean_column_names = list( + map(_clean_column_name, colnames) + ) # create list because py3 map object will be consumed by first attempt + log.warning( + "Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. " + 'Avoid this by choosing different names, using SELECT "" AS aliases, ' + "or specifying a different row_factory on your Session" + % (colnames, clean_column_names) + ) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) + _named_tuple_cache[key] = Row return [Row(*row) for row in rows] @@ -276,11 +294,24 @@ class Statement(object): _serial_consistency_level = None _routing_key = None - def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, - is_idempotent=False, table=None): - if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors - raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + def __init__( + self, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + table=None, + ): + if retry_policy and not hasattr( + retry_policy, "on_read_timeout" + ): # just checking one method to detect positional parameter errors + raise ValueError( + "retry_policy should implement cassandra.policies.RetryPolicy" + ) if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: @@ -329,17 +360,20 @@ def _del_routing_key(self): If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. - """) + """, + ) def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - if (serial_consistency_level is not None and - not ConsistencyLevel.is_serial(serial_consistency_level)): + if serial_consistency_level is not None and not ConsistencyLevel.is_serial( + serial_consistency_level + ): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL") + "or ConsistencyLevel.LOCAL_SERIAL" + ) self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): @@ -384,7 +418,8 @@ def is_lwt(self): conditional statements. .. versionadded:: 2.0.0 - """) + """, + ) class SimpleStatement(Statement): @@ -392,9 +427,18 @@ class SimpleStatement(Statement): A simple, un-prepared query. """ - def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None, is_idempotent=False): + def __init__( + self, + query_string, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + ): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the @@ -402,8 +446,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + is_idempotent, + ) self._query_string = query_string @property @@ -411,9 +464,14 @@ def query_string(self): return self._query_string def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -442,7 +500,7 @@ class PreparedStatement(object): A note about * in prepared statements """ - column_metadata = None #TODO: make this bind_metadata in next major + column_metadata = None # TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None @@ -459,9 +517,19 @@ class PreparedStatement(object): serial_consistency_level = None # TODO never used? _is_lwt = False - def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata, result_metadata_id, - is_lwt=False, column_encryption_policy=None): + def __init__( + self, + column_metadata, + query_id, + routing_key_indexes, + query, + keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt=False, + column_encryption_policy=None, + ): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -475,13 +543,33 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self._is_lwt = is_lwt @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy=None): + def from_message( + cls, + query_id, + column_metadata, + pk_indexes, + cluster_metadata, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy=None, + ): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + return PreparedStatement( + column_metadata, + query_id, + None, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) if pk_indexes: routing_key_indexes = pk_indexes @@ -496,18 +584,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement - statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) + statement_indexes = dict( + (c.name, i) for i, c in enumerate(column_metadata) + ) # a list of which indexes in the statement correspond to partition key items try: - routing_key_indexes = [statement_indexes[c.name] - for c in partition_key_columns] - except KeyError: # we're missing a partition key component in the prepared - pass # statement; just leave routing_key_indexes as None - - return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + routing_key_indexes = [ + statement_indexes[c.name] for c in partition_key_columns + ] + except ( + KeyError + ): # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None + + return PreparedStatement( + column_metadata, + query_id, + routing_key_indexes, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) def bind(self, values): """ @@ -519,16 +621,23 @@ def bind(self, values): def is_routing_key_index(self, i): if self._routing_key_index_set is None: - self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() + self._routing_key_index_set = ( + set(self.routing_key_indexes) if self.routing_key_indexes else set() + ) return i in self._routing_key_index_set def is_lwt(self): return self._is_lwt def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -548,9 +657,17 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + def __init__( + self, + prepared_statement, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + ): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. @@ -571,9 +688,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None self.keyspace = meta[0].keyspace_name self.table = meta[0].table_name - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, - prepared_statement.is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + prepared_statement.is_idempotent, + ) def bind(self, values): """ @@ -615,24 +740,29 @@ def bind(self, values): values.append(UNSET_VALUE) else: raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + "Column name `%s` not found in bound dict." % (col.name) + ) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( - "Too many arguments provided to bind() (got %d, expected %d)" % - (len(values), len(col_meta))) + "Too many arguments provided to bind() (got %d, expected %d)" + % (len(values), len(col_meta)) + ) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. - if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ - value_len < len(self.prepared_statement.routing_key_indexes): + if ( + proto_version < 4 + and self.prepared_statement.routing_key_indexes + and value_len < len(self.prepared_statement.routing_key_indexes) + ): raise ValueError( - "Too few arguments provided to bind() (got %d, required %d for routing key)" % - (value_len, len(self.prepared_statement.routing_key_indexes))) + "Too few arguments provided to bind() (got %d, required %d for routing key)" + % (value_len, len(self.prepared_statement.routing_key_indexes)) + ) self.raw_values = values self.values = [] @@ -643,20 +773,30 @@ def bind(self, values): if proto_version >= 4: self._append_unset_value() else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + col_desc = ColDesc( + col_spec.keyspace_name, col_spec.table_name, col_spec.name + ) uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type + col_type = ( + ce_policy.column_type(col_desc) if uses_ce else col_spec.type + ) col_bytes = col_type.serialize(value, proto_version) if uses_ce: col_bytes = ce_policy.encrypt(col_desc, col_bytes) self.values.append(col_bytes) except (TypeError, struct.error) as exc: actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + message = ( + 'Received an argument of invalid type for column "%s". ' + "Expected: %s, Got: %s; (%s)" + % (col_spec.name, col_spec.type, actual_type, exc) + ) raise TypeError(message) if proto_version >= 4: @@ -671,7 +811,10 @@ def _append_unset_value(self): next_index = len(self.values) if self.prepared_statement.is_routing_key_index(next_index): col_meta = self.prepared_statement.column_metadata[next_index] - raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) + raise ValueError( + "Cannot bind UNSET_VALUE as a part of the routing key '%s'" + % col_meta.name + ) self.values.append(UNSET_VALUE) @property @@ -686,7 +829,9 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) + self._routing_key = b"".join( + self._key_parts_packed(self.values[i] for i in routing_indexes) + ) return self._routing_key @@ -694,9 +839,15 @@ def is_lwt(self): return self.prepared_statement.is_lwt() def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.prepared_statement.query_string, self.raw_values, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.prepared_statement.query_string, + self.raw_values, + consistency, + ) + __repr__ = __str__ @@ -731,7 +882,7 @@ def __str__(self): return self.name def __repr__(self): - return "BatchType.%s" % (self.name, ) + return "BatchType.%s" % (self.name,) BatchType.LOGGED = BatchType("LOGGED", 0) @@ -763,9 +914,15 @@ class BatchStatement(Statement): _session = None _is_lwt = False - def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, - consistency_level=None, serial_consistency_level=None, - session=None, custom_payload=None): + def __init__( + self, + batch_type=BatchType.LOGGED, + retry_policy=None, + consistency_level=None, + serial_consistency_level=None, + session=None, + custom_payload=None, + ): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. @@ -813,8 +970,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, self.batch_type = batch_type self._statements_and_parameters = [] self._session = session - Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, - serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + Statement.__init__( + self, + retry_policy=retry_policy, + consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level, + custom_payload=custom_payload, + ) def clear(self): """ @@ -853,11 +1015,14 @@ def add(self, statement, parameters=None): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " - "to BatchStatement.add()") + "to BatchStatement.add()" + ) self._update_state(statement) if statement.is_lwt(): self._is_lwt = True - self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) + self._add_statement_and_params( + True, statement.prepared_statement.query_id, statement.values + ) else: # it must be a SimpleStatement query_string = statement.query_string @@ -881,7 +1046,9 @@ def add_all(self, statements, parameters): def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: - raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + raise ValueError( + "Batch statement cannot contain more than %d statements." % 0xFFFF + ) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): @@ -907,9 +1074,15 @@ def __len__(self): return len(self._statements_and_parameters) def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.batch_type, len(self), consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return "" % ( + self.batch_type, + len(self), + consistency, + ) + __repr__ = __str__ @@ -931,7 +1104,9 @@ def __str__(self): def bind_params(query, params, encoder): if isinstance(params, dict): - return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) + return query % dict( + (k, encoder.cql_encode_all_types(v)) for k, v in params.items() + ) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) @@ -940,6 +1115,7 @@ class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ + pass @@ -1000,7 +1176,9 @@ class QueryTrace(object): _session = None - _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" + _SELECT_SESSIONS_FORMAT = ( + "SELECT * FROM system_traces.sessions WHERE session_id = %s" + ) _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 @@ -1029,18 +1207,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( - "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." + % (max_wait,) + ) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) - metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout + metadata_request_timeout = ( + self._session.cluster.control_connection + and self._session.cluster.control_connection._metadata_request_timeout + ) session_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_SESSIONS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), + (self.trace_id,), + time_spent, + max_wait, + ) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries session_row = session_results.one() if session_results else None - is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None + is_complete = ( + session_row is not None + and session_row.duration is not None + and session_row.started_at is not None + ) if not session_results or (wait_for_complete and not is_complete): - time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + time.sleep(self._BASE_RETRY_SLEEP * (2**attempt)) attempt += 1 continue if is_complete: @@ -1049,29 +1245,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) self.request_type = session_row.request - self.duration = timedelta(microseconds=session_row.duration) if is_complete else None + self.duration = ( + timedelta(microseconds=session_row.duration) if is_complete else None + ) self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 - self.client = getattr(session_row, 'client', None) + self.client = getattr(session_row, "client", None) - log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + log.debug( + "Attempting to fetch trace events for trace ID: %s", self.trace_id + ) time_spent = time.time() - start event_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout), - consistency_level=query_cl), + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_EVENTS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), (self.trace_id,), time_spent, - max_wait) + max_wait, + ) log.debug("Fetched trace events for trace ID: %s", self.trace_id) - self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) - for r in event_results) + self.events = tuple( + TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) + for r in event_results + ) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_response_future( + query, parameters, trace=False, custom_payload=None, timeout=timeout + ) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() @@ -1079,12 +1288,22 @@ def _execute(self, query, parameters, time_spent, max_wait): try: return future.result() except OperationTimedOut: - raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + raise TraceUnavailable( + "Trace information was not available within %f seconds" % (max_wait,) + ) def __str__(self): - return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ - % (self.request_type, self.trace_id, self.coordinator, self.started_at, - self.duration, self.parameters) + return ( + "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" + % ( + self.request_type, + self.trace_id, + self.coordinator, + self.started_at, + self.duration, + self.parameters, + ) + ) class TraceEvent(object): @@ -1121,7 +1340,9 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) + self.datetime = datetime.fromtimestamp( + unix_time_from_uuid1(timeuuid), tz=timezone.utc + ) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) @@ -1130,7 +1351,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.thread_name = thread_name def __str__(self): - return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + return "%s on %s[%s] at %s" % ( + self.description, + self.source, + self.thread_name, + self.datetime, + ) # TODO remove next major since we can target using the `host` attribute of session.execute @@ -1139,9 +1365,12 @@ class HostTargetingStatement(object): Wraps any query statement and attaches a target host, making it usable in a targeted LBP without modifying the user's statement. """ + def __init__(self, inner_statement, target_host): - self.__class__ = type(inner_statement.__class__.__name__, - (self.__class__, inner_statement.__class__), - {}) - self.__dict__ = inner_statement.__dict__ - self.target_host = target_host + self.__class__ = type( + inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}, + ) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host