diff --git a/benchmarks/test_named_tuple_factory_benchmark.py b/benchmarks/test_named_tuple_factory_benchmark.py
new file mode 100644
index 0000000000..ed4b7be8ca
--- /dev/null
+++ b/benchmarks/test_named_tuple_factory_benchmark.py
@@ -0,0 +1,206 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Benchmarks for named_tuple_factory with and without namedtuple class caching.
+
+Run with: pytest benchmarks/test_named_tuple_factory_benchmark.py -v
+"""
+
+import re
+import warnings
+from collections import namedtuple
+
+import pytest
+
+from cassandra.query import named_tuple_factory, _named_tuple_cache
+from cassandra.util import _sanitize_identifiers
+
+
+# ---------------------------------------------------------------------------
+# Reference: original uncached implementation (copied from master)
+# ---------------------------------------------------------------------------
+
+NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]")
+START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*")
+END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$")
+
+_clean_name_cache_old = {}
+
+
+def _clean_column_name_old(name):
+ try:
+ return _clean_name_cache_old[name]
+ except KeyError:
+ clean = NON_ALPHA_REGEX.sub(
+ "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))
+ )
+ _clean_name_cache_old[name] = clean
+ return clean
+
+
+def named_tuple_factory_uncached(colnames, rows):
+ """Original implementation without caching (for benchmark comparison)."""
+ clean_column_names = map(_clean_column_name_old, colnames)
+ try:
+ Row = namedtuple("Row", clean_column_names)
+ except SyntaxError:
+ raise
+ except Exception:
+ clean_column_names = list(map(_clean_column_name_old, colnames))
+ Row = namedtuple("Row", _sanitize_identifiers(clean_column_names))
+ return [Row(*row) for row in rows]
+
+
+# ---------------------------------------------------------------------------
+# Test data generators
+# ---------------------------------------------------------------------------
+
+
+def make_colnames(n):
+ return tuple(f"col_{i}" for i in range(n))
+
+
+def make_rows(ncols, nrows):
+ return [tuple(range(ncols)) for _ in range(nrows)]
+
+
+# ---------------------------------------------------------------------------
+# Correctness tests
+# ---------------------------------------------------------------------------
+
+
+class TestNamedTupleFactoryCorrectness:
+ """Verify the cached implementation matches the uncached one."""
+
+ @pytest.mark.parametrize("ncols", [1, 5, 10, 20])
+ @pytest.mark.parametrize("nrows", [1, 10, 100])
+ def test_results_match(self, ncols, nrows):
+ colnames = make_colnames(ncols)
+ rows = make_rows(ncols, nrows)
+ _named_tuple_cache.clear()
+ cached_result = named_tuple_factory(colnames, rows)
+ uncached_result = named_tuple_factory_uncached(colnames, rows)
+ assert len(cached_result) == len(uncached_result)
+ for cr, ur in zip(cached_result, uncached_result):
+ assert tuple(cr) == tuple(ur)
+ assert cr._fields == ur._fields
+
+ def test_cache_hit_returns_same_class(self):
+ colnames = ("name", "age", "email")
+ rows1 = [("Alice", 30, "a@b.com")]
+ rows2 = [("Bob", 25, "b@c.com")]
+ _named_tuple_cache.clear()
+ result1 = named_tuple_factory(colnames, rows1)
+ result2 = named_tuple_factory(colnames, rows2)
+ # Same Row class should be reused
+ assert type(result1[0]) is type(result2[0])
+
+ def test_different_schemas_get_different_classes(self):
+ _named_tuple_cache.clear()
+ result1 = named_tuple_factory(("a", "b"), [(1, 2)])
+ result2 = named_tuple_factory(("x", "y"), [(3, 4)])
+ assert type(result1[0]) is not type(result2[0])
+ assert result1[0]._fields == ("a", "b")
+ assert result2[0]._fields == ("x", "y")
+
+
+# ---------------------------------------------------------------------------
+# Benchmarks
+# ---------------------------------------------------------------------------
+
+
+class TestNamedTupleFactoryBenchmark:
+ """Benchmark cached vs uncached named_tuple_factory."""
+
+ # --- 5 columns, 100 rows ---
+
+ @pytest.mark.benchmark(group="ntf_5cols_100rows")
+ def test_uncached_5cols_100rows(self, benchmark):
+ colnames = make_colnames(5)
+ rows = make_rows(5, 100)
+ benchmark(named_tuple_factory_uncached, colnames, rows)
+
+ @pytest.mark.benchmark(group="ntf_5cols_100rows")
+ def test_cached_5cols_100rows(self, benchmark):
+ colnames = make_colnames(5)
+ rows = make_rows(5, 100)
+ _named_tuple_cache.clear()
+ # Warm the cache with one call
+ named_tuple_factory(colnames, rows)
+ benchmark(named_tuple_factory, colnames, rows)
+
+ # --- 10 columns, 100 rows ---
+
+ @pytest.mark.benchmark(group="ntf_10cols_100rows")
+ def test_uncached_10cols_100rows(self, benchmark):
+ colnames = make_colnames(10)
+ rows = make_rows(10, 100)
+ benchmark(named_tuple_factory_uncached, colnames, rows)
+
+ @pytest.mark.benchmark(group="ntf_10cols_100rows")
+ def test_cached_10cols_100rows(self, benchmark):
+ colnames = make_colnames(10)
+ rows = make_rows(10, 100)
+ _named_tuple_cache.clear()
+ named_tuple_factory(colnames, rows)
+ benchmark(named_tuple_factory, colnames, rows)
+
+ # --- 20 columns, 100 rows ---
+
+ @pytest.mark.benchmark(group="ntf_20cols_100rows")
+ def test_uncached_20cols_100rows(self, benchmark):
+ colnames = make_colnames(20)
+ rows = make_rows(20, 100)
+ benchmark(named_tuple_factory_uncached, colnames, rows)
+
+ @pytest.mark.benchmark(group="ntf_20cols_100rows")
+ def test_cached_20cols_100rows(self, benchmark):
+ colnames = make_colnames(20)
+ rows = make_rows(20, 100)
+ _named_tuple_cache.clear()
+ named_tuple_factory(colnames, rows)
+ benchmark(named_tuple_factory, colnames, rows)
+
+ # --- 5 columns, 1000 rows ---
+
+ @pytest.mark.benchmark(group="ntf_5cols_1000rows")
+ def test_uncached_5cols_1000rows(self, benchmark):
+ colnames = make_colnames(5)
+ rows = make_rows(5, 1000)
+ benchmark(named_tuple_factory_uncached, colnames, rows)
+
+ @pytest.mark.benchmark(group="ntf_5cols_1000rows")
+ def test_cached_5cols_1000rows(self, benchmark):
+ colnames = make_colnames(5)
+ rows = make_rows(5, 1000)
+ _named_tuple_cache.clear()
+ named_tuple_factory(colnames, rows)
+ benchmark(named_tuple_factory, colnames, rows)
+
+ # --- 10 columns, 1 row (measures class creation overhead most clearly) ---
+
+ @pytest.mark.benchmark(group="ntf_10cols_1row")
+ def test_uncached_10cols_1row(self, benchmark):
+ colnames = make_colnames(10)
+ rows = make_rows(10, 1)
+ benchmark(named_tuple_factory_uncached, colnames, rows)
+
+ @pytest.mark.benchmark(group="ntf_10cols_1row")
+ def test_cached_10cols_1row(self, benchmark):
+ colnames = make_colnames(10)
+ rows = make_rows(10, 1)
+ _named_tuple_cache.clear()
+ named_tuple_factory(colnames, rows)
+ benchmark(named_tuple_factory, colnames, rows)
diff --git a/cassandra/query.py b/cassandra/query.py
index 6c6878fdb4..1add236d56 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -34,6 +34,7 @@
from cassandra.util import OrderedDict, _sanitize_identifiers
import logging
+
log = logging.getLogger(__name__)
UNSET_VALUE = _UNSET_VALUE
@@ -49,9 +50,9 @@
Only valid when using native protocol v4+
"""
-NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
-START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
-END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$')
+NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]")
+START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*")
+END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$")
_clean_name_cache = {}
@@ -60,7 +61,9 @@ def _clean_column_name(name):
try:
return _clean_name_cache[name]
except KeyError:
- clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)))
+ clean = NON_ALPHA_REGEX.sub(
+ "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))
+ )
_clean_name_cache[name] = clean
return clean
@@ -83,6 +86,7 @@ def tuple_factory(colnames, rows):
"""
return rows
+
class PseudoNamedTupleRow(object):
"""
Helper class for pseudo_named_tuple_factory. These objects provide an
@@ -90,6 +94,7 @@ class PseudoNamedTupleRow(object):
but otherwise do not attempt to implement the full namedtuple or iterable
interface.
"""
+
def __init__(self, ordered_dict):
self._dict = ordered_dict
self._tuple = tuple(ordered_dict.values())
@@ -104,8 +109,7 @@ def __iter__(self):
return iter(self._tuple)
def __repr__(self):
- return '{t}({od})'.format(t=self.__class__.__name__,
- od=self._dict)
+ return "{t}({od})".format(t=self.__class__.__name__, od=self._dict)
def pseudo_namedtuple_factory(colnames, rows):
@@ -113,8 +117,13 @@ def pseudo_namedtuple_factory(colnames, rows):
Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback
factory for cases where :meth:`.named_tuple_factory` fails to create rows.
"""
- return [PseudoNamedTupleRow(od)
- for od in ordered_dict_factory(colnames, rows)]
+ return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)]
+
+
+# Cache namedtuple Row classes to avoid repeated exec() calls in namedtuple()
+# for the same column schema. Naturally bounded by the number of distinct
+# column-name tuples, which equals the number of distinct queries.
+_named_tuple_cache = {}
def named_tuple_factory(colnames, rows):
@@ -146,32 +155,41 @@ def named_tuple_factory(colnames, rows):
.. versionchanged:: 2.0.0
moved from ``cassandra.decoder`` to ``cassandra.query``
"""
- clean_column_names = map(_clean_column_name, colnames)
+ key = tuple(colnames)
try:
- Row = namedtuple('Row', clean_column_names)
- except SyntaxError:
- warnings.warn(
- "Failed creating namedtuple for a result because there were too "
- "many columns. This is due to a Python limitation that affects "
- "namedtuple in Python 3.0-3.6 (see issue18896). The row will be "
- "created with {substitute_factory_name}, which lacks some namedtuple "
- "features and is slower. To avoid slower performance accessing "
- "values on row objects, Upgrade to Python 3.7, or use a different "
- "row factory. (column names: {colnames})".format(
- substitute_factory_name=pseudo_namedtuple_factory.__name__,
- colnames=colnames
+ Row = _named_tuple_cache[key]
+ except KeyError:
+ clean_column_names = map(_clean_column_name, colnames)
+ try:
+ Row = namedtuple("Row", clean_column_names)
+ except SyntaxError:
+ warnings.warn(
+ "Failed creating namedtuple for a result because there were too "
+ "many columns. This is due to a Python limitation that affects "
+ "namedtuple in Python 3.0-3.6 (see issue18896). The row will be "
+ "created with {substitute_factory_name}, which lacks some namedtuple "
+ "features and is slower. To avoid slower performance accessing "
+ "values on row objects, Upgrade to Python 3.7, or use a different "
+ "row factory. (column names: {colnames})".format(
+ substitute_factory_name=pseudo_namedtuple_factory.__name__,
+ colnames=colnames,
+ )
)
- )
- return pseudo_namedtuple_factory(colnames, rows)
- except Exception:
- clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt
- log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) "
- "(see Python 'namedtuple' documentation for details on name rules). "
- "Results will be returned with positional names. "
- "Avoid this by choosing different names, using SELECT \"<col-name>\" AS aliases, "
- "or specifying a different row_factory on your Session" %
- (colnames, clean_column_names))
- Row = namedtuple('Row', _sanitize_identifiers(clean_column_names))
+ return pseudo_namedtuple_factory(colnames, rows)
+ except Exception:
+ clean_column_names = list(
+ map(_clean_column_name, colnames)
+ ) # create list because py3 map object will be consumed by first attempt
+ log.warning(
+ "Failed creating named tuple for results with column names %s (cleaned: %s) "
+ "(see Python 'namedtuple' documentation for details on name rules). "
+ "Results will be returned with positional names. "
+ 'Avoid this by choosing different names, using SELECT "<col-name>" AS aliases, '
+ "or specifying a different row_factory on your Session"
+ % (colnames, clean_column_names)
+ )
+ Row = namedtuple("Row", _sanitize_identifiers(clean_column_names))
+ _named_tuple_cache[key] = Row
return [Row(*row) for row in rows]
@@ -276,11 +294,24 @@ class Statement(object):
_serial_consistency_level = None
_routing_key = None
- def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
- serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
- is_idempotent=False, table=None):
- if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors
- raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
+ def __init__(
+ self,
+ retry_policy=None,
+ consistency_level=None,
+ routing_key=None,
+ serial_consistency_level=None,
+ fetch_size=FETCH_SIZE_UNSET,
+ keyspace=None,
+ custom_payload=None,
+ is_idempotent=False,
+ table=None,
+ ):
+ if retry_policy and not hasattr(
+ retry_policy, "on_read_timeout"
+ ): # just checking one method to detect positional parameter errors
+ raise ValueError(
+ "retry_policy should implement cassandra.policies.RetryPolicy"
+ )
if retry_policy is not None:
self.retry_policy = retry_policy
if consistency_level is not None:
@@ -329,17 +360,20 @@ def _del_routing_key(self):
If the partition key is a composite, a list or tuple must be passed in.
Each key component should be in its packed (binary) format, so all
components should be strings.
- """)
+ """,
+ )
def _get_serial_consistency_level(self):
return self._serial_consistency_level
def _set_serial_consistency_level(self, serial_consistency_level):
- if (serial_consistency_level is not None and
- not ConsistencyLevel.is_serial(serial_consistency_level)):
+ if serial_consistency_level is not None and not ConsistencyLevel.is_serial(
+ serial_consistency_level
+ ):
raise ValueError(
"serial_consistency_level must be either ConsistencyLevel.SERIAL "
- "or ConsistencyLevel.LOCAL_SERIAL")
+ "or ConsistencyLevel.LOCAL_SERIAL"
+ )
self._serial_consistency_level = serial_consistency_level
def _del_serial_consistency_level(self):
@@ -384,7 +418,8 @@ def is_lwt(self):
conditional statements.
.. versionadded:: 2.0.0
- """)
+ """,
+ )
class SimpleStatement(Statement):
@@ -392,9 +427,18 @@ class SimpleStatement(Statement):
A simple, un-prepared query.
"""
- def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None,
- serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
- custom_payload=None, is_idempotent=False):
+ def __init__(
+ self,
+ query_string,
+ retry_policy=None,
+ consistency_level=None,
+ routing_key=None,
+ serial_consistency_level=None,
+ fetch_size=FETCH_SIZE_UNSET,
+ keyspace=None,
+ custom_payload=None,
+ is_idempotent=False,
+ ):
"""
`query_string` should be a literal CQL statement with the exception
of parameter placeholders that will be filled through the
@@ -402,8 +446,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout
See :class:`Statement` attributes for a description of the other parameters.
"""
- Statement.__init__(self, retry_policy, consistency_level, routing_key,
- serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent)
+ Statement.__init__(
+ self,
+ retry_policy,
+ consistency_level,
+ routing_key,
+ serial_consistency_level,
+ fetch_size,
+ keyspace,
+ custom_payload,
+ is_idempotent,
+ )
self._query_string = query_string
@property
@@ -411,9 +464,14 @@ def query_string(self):
return self._query_string
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
- return (u'<SimpleStatement query="%s", consistency=%s>' %
- (self.query_string, consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+ return '<SimpleStatement query="%s", consistency=%s>' % (
+ self.query_string,
+ consistency,
+ )
+
__repr__ = __str__
@@ -442,7 +500,7 @@ class PreparedStatement(object):
A note about * in prepared statements
"""
- column_metadata = None #TODO: make this bind_metadata in next major
+ column_metadata = None # TODO: make this bind_metadata in next major
retry_policy = None
consistency_level = None
custom_payload = None
@@ -459,9 +517,19 @@ class PreparedStatement(object):
serial_consistency_level = None # TODO never used?
_is_lwt = False
- def __init__(self, column_metadata, query_id, routing_key_indexes, query,
- keyspace, protocol_version, result_metadata, result_metadata_id,
- is_lwt=False, column_encryption_policy=None):
+ def __init__(
+ self,
+ column_metadata,
+ query_id,
+ routing_key_indexes,
+ query,
+ keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt=False,
+ column_encryption_policy=None,
+ ):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
@@ -475,13 +543,33 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
self._is_lwt = is_lwt
@classmethod
- def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
- query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, is_lwt, column_encryption_policy=None):
+ def from_message(
+ cls,
+ query_id,
+ column_metadata,
+ pk_indexes,
+ cluster_metadata,
+ query,
+ prepared_keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt,
+ column_encryption_policy=None,
+ ):
if not column_metadata:
- return PreparedStatement(column_metadata, query_id, None,
- query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, is_lwt, column_encryption_policy)
+ return PreparedStatement(
+ column_metadata,
+ query_id,
+ None,
+ query,
+ prepared_keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt,
+ column_encryption_policy,
+ )
if pk_indexes:
routing_key_indexes = pk_indexes
@@ -496,18 +584,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
partition_key_columns = table_meta.partition_key
# make a map of {column_name: index} for each column in the statement
- statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
+ statement_indexes = dict(
+ (c.name, i) for i, c in enumerate(column_metadata)
+ )
# a list of which indexes in the statement correspond to partition key items
try:
- routing_key_indexes = [statement_indexes[c.name]
- for c in partition_key_columns]
- except KeyError: # we're missing a partition key component in the prepared
- pass # statement; just leave routing_key_indexes as None
-
- return PreparedStatement(column_metadata, query_id, routing_key_indexes,
- query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, is_lwt, column_encryption_policy)
+ routing_key_indexes = [
+ statement_indexes[c.name] for c in partition_key_columns
+ ]
+ except (
+ KeyError
+ ): # we're missing a partition key component in the prepared
+ pass # statement; just leave routing_key_indexes as None
+
+ return PreparedStatement(
+ column_metadata,
+ query_id,
+ routing_key_indexes,
+ query,
+ prepared_keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt,
+ column_encryption_policy,
+ )
def bind(self, values):
"""
@@ -519,16 +621,23 @@ def bind(self, values):
def is_routing_key_index(self, i):
if self._routing_key_index_set is None:
- self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
+ self._routing_key_index_set = (
+ set(self.routing_key_indexes) if self.routing_key_indexes else set()
+ )
return i in self._routing_key_index_set
def is_lwt(self):
return self._is_lwt
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
- return (u'<PreparedStatement query="%s", consistency=%s>' %
- (self.query_string, consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+ return '<PreparedStatement query="%s", consistency=%s>' % (
+ self.query_string,
+ consistency,
+ )
+
__repr__ = __str__
@@ -548,9 +657,17 @@ class BoundStatement(Statement):
The sequence of values that were bound to the prepared statement.
"""
- def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None,
- serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
- custom_payload=None):
+ def __init__(
+ self,
+ prepared_statement,
+ retry_policy=None,
+ consistency_level=None,
+ routing_key=None,
+ serial_consistency_level=None,
+ fetch_size=FETCH_SIZE_UNSET,
+ keyspace=None,
+ custom_payload=None,
+ ):
"""
`prepared_statement` should be an instance of :class:`PreparedStatement`.
@@ -571,9 +688,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None
self.keyspace = meta[0].keyspace_name
self.table = meta[0].table_name
- Statement.__init__(self, retry_policy, consistency_level, routing_key,
- serial_consistency_level, fetch_size, keyspace, custom_payload,
- prepared_statement.is_idempotent)
+ Statement.__init__(
+ self,
+ retry_policy,
+ consistency_level,
+ routing_key,
+ serial_consistency_level,
+ fetch_size,
+ keyspace,
+ custom_payload,
+ prepared_statement.is_idempotent,
+ )
def bind(self, values):
"""
@@ -615,24 +740,29 @@ def bind(self, values):
values.append(UNSET_VALUE)
else:
raise KeyError(
- 'Column name `%s` not found in bound dict.' %
- (col.name))
+ "Column name `%s` not found in bound dict." % (col.name)
+ )
value_len = len(values)
col_meta_len = len(col_meta)
if value_len > col_meta_len:
raise ValueError(
- "Too many arguments provided to bind() (got %d, expected %d)" %
- (len(values), len(col_meta)))
+ "Too many arguments provided to bind() (got %d, expected %d)"
+ % (len(values), len(col_meta))
+ )
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
# the error will be better reported when UNSET_VALUE is implicitly added.
- if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
- value_len < len(self.prepared_statement.routing_key_indexes):
+ if (
+ proto_version < 4
+ and self.prepared_statement.routing_key_indexes
+ and value_len < len(self.prepared_statement.routing_key_indexes)
+ ):
raise ValueError(
- "Too few arguments provided to bind() (got %d, required %d for routing key)" %
- (value_len, len(self.prepared_statement.routing_key_indexes)))
+ "Too few arguments provided to bind() (got %d, required %d for routing key)"
+ % (value_len, len(self.prepared_statement.routing_key_indexes))
+ )
self.raw_values = values
self.values = []
@@ -643,20 +773,30 @@ def bind(self, values):
if proto_version >= 4:
self._append_unset_value()
else:
- raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
+ raise ValueError(
+ "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)"
+ % proto_version
+ )
else:
try:
- col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name)
+ col_desc = ColDesc(
+ col_spec.keyspace_name, col_spec.table_name, col_spec.name
+ )
uses_ce = ce_policy and ce_policy.contains_column(col_desc)
- col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type
+ col_type = (
+ ce_policy.column_type(col_desc) if uses_ce else col_spec.type
+ )
col_bytes = col_type.serialize(value, proto_version)
if uses_ce:
col_bytes = ce_policy.encrypt(col_desc, col_bytes)
self.values.append(col_bytes)
except (TypeError, struct.error) as exc:
actual_type = type(value)
- message = ('Received an argument of invalid type for column "%s". '
- 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
+ message = (
+ 'Received an argument of invalid type for column "%s". '
+ "Expected: %s, Got: %s; (%s)"
+ % (col_spec.name, col_spec.type, actual_type, exc)
+ )
raise TypeError(message)
if proto_version >= 4:
@@ -671,7 +811,10 @@ def _append_unset_value(self):
next_index = len(self.values)
if self.prepared_statement.is_routing_key_index(next_index):
col_meta = self.prepared_statement.column_metadata[next_index]
- raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
+ raise ValueError(
+ "Cannot bind UNSET_VALUE as a part of the routing key '%s'"
+ % col_meta.name
+ )
self.values.append(UNSET_VALUE)
@property
@@ -686,7 +829,9 @@ def routing_key(self):
if len(routing_indexes) == 1:
self._routing_key = self.values[routing_indexes[0]]
else:
- self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes))
+ self._routing_key = b"".join(
+ self._key_parts_packed(self.values[i] for i in routing_indexes)
+ )
return self._routing_key
@@ -694,9 +839,15 @@ def is_lwt(self):
return self.prepared_statement.is_lwt()
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
- return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %
- (self.prepared_statement.query_string, self.raw_values, consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+ return '<BoundStatement query="%s", values=%s, consistency=%s>' % (
+ self.prepared_statement.query_string,
+ self.raw_values,
+ consistency,
+ )
+
__repr__ = __str__
@@ -731,7 +882,7 @@ def __str__(self):
return self.name
def __repr__(self):
- return "BatchType.%s" % (self.name, )
+ return "BatchType.%s" % (self.name,)
BatchType.LOGGED = BatchType("LOGGED", 0)
@@ -763,9 +914,15 @@ class BatchStatement(Statement):
_session = None
_is_lwt = False
- def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
- consistency_level=None, serial_consistency_level=None,
- session=None, custom_payload=None):
+ def __init__(
+ self,
+ batch_type=BatchType.LOGGED,
+ retry_policy=None,
+ consistency_level=None,
+ serial_consistency_level=None,
+ session=None,
+ custom_payload=None,
+ ):
"""
`batch_type` specifies The :class:`.BatchType` for the batch operation.
Defaults to :attr:`.BatchType.LOGGED`.
@@ -813,8 +970,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
self.batch_type = batch_type
self._statements_and_parameters = []
self._session = session
- Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
- serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
+ Statement.__init__(
+ self,
+ retry_policy=retry_policy,
+ consistency_level=consistency_level,
+ serial_consistency_level=serial_consistency_level,
+ custom_payload=custom_payload,
+ )
def clear(self):
"""
@@ -853,11 +1015,14 @@ def add(self, statement, parameters=None):
if parameters:
raise ValueError(
"Parameters cannot be passed with a BoundStatement "
- "to BatchStatement.add()")
+ "to BatchStatement.add()"
+ )
self._update_state(statement)
if statement.is_lwt():
self._is_lwt = True
- self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values)
+ self._add_statement_and_params(
+ True, statement.prepared_statement.query_id, statement.values
+ )
else:
# it must be a SimpleStatement
query_string = statement.query_string
@@ -881,7 +1046,9 @@ def add_all(self, statements, parameters):
def _add_statement_and_params(self, is_prepared, statement, parameters):
if len(self._statements_and_parameters) >= 0xFFFF:
- raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF)
+ raise ValueError(
+ "Batch statement cannot contain more than %d statements." % 0xFFFF
+ )
self._statements_and_parameters.append((is_prepared, statement, parameters))
def _maybe_set_routing_attributes(self, statement):
@@ -907,9 +1074,15 @@ def __len__(self):
return len(self._statements_and_parameters)
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
- return (u'<BatchStatement type=%s, statements=%s, consistency=%s>' %
- (self.batch_type, len(self), consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+ return "<BatchStatement type=%s, statements=%s, consistency=%s>" % (
+ self.batch_type,
+ len(self),
+ consistency,
+ )
+
__repr__ = __str__
@@ -931,7 +1104,9 @@ def __str__(self):
def bind_params(query, params, encoder):
if isinstance(params, dict):
- return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items())
+ return query % dict(
+ (k, encoder.cql_encode_all_types(v)) for k, v in params.items()
+ )
else:
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
@@ -940,6 +1115,7 @@ class TraceUnavailable(Exception):
"""
Raised when complete trace details cannot be fetched from Cassandra.
"""
+
pass
@@ -1000,7 +1176,9 @@ class QueryTrace(object):
_session = None
- _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s"
+ _SELECT_SESSIONS_FORMAT = (
+ "SELECT * FROM system_traces.sessions WHERE session_id = %s"
+ )
_SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s"
_BASE_RETRY_SLEEP = 0.003
@@ -1029,18 +1207,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
time_spent = time.time() - start
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable(
- "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,))
+ "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait."
+ % (max_wait,)
+ )
log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
- metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout
+ metadata_request_timeout = (
+ self._session.cluster.control_connection
+ and self._session.cluster.control_connection._metadata_request_timeout
+ )
session_results = self._execute(
- SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait)
+ SimpleStatement(
+ maybe_add_timeout_to_query(
+ self._SELECT_SESSIONS_FORMAT, metadata_request_timeout
+ ),
+ consistency_level=query_cl,
+ ),
+ (self.trace_id,),
+ time_spent,
+ max_wait,
+ )
# PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries
session_row = session_results.one() if session_results else None
- is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None
+ is_complete = (
+ session_row is not None
+ and session_row.duration is not None
+ and session_row.started_at is not None
+ )
if not session_results or (wait_for_complete and not is_complete):
- time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
+ time.sleep(self._BASE_RETRY_SLEEP * (2**attempt))
attempt += 1
continue
if is_complete:
@@ -1049,29 +1245,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
log.debug("Fetching parital trace info for trace ID: %s", self.trace_id)
self.request_type = session_row.request
- self.duration = timedelta(microseconds=session_row.duration) if is_complete else None
+ self.duration = (
+ timedelta(microseconds=session_row.duration) if is_complete else None
+ )
self.started_at = session_row.started_at
self.coordinator = session_row.coordinator
self.parameters = session_row.parameters
# since C* 2.2
- self.client = getattr(session_row, 'client', None)
+ self.client = getattr(session_row, "client", None)
- log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id)
+ log.debug(
+ "Attempting to fetch trace events for trace ID: %s", self.trace_id
+ )
time_spent = time.time() - start
event_results = self._execute(
- SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout),
- consistency_level=query_cl),
+ SimpleStatement(
+ maybe_add_timeout_to_query(
+ self._SELECT_EVENTS_FORMAT, metadata_request_timeout
+ ),
+ consistency_level=query_cl,
+ ),
(self.trace_id,),
time_spent,
- max_wait)
+ max_wait,
+ )
log.debug("Fetched trace events for trace ID: %s", self.trace_id)
- self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
- for r in event_results)
+ self.events = tuple(
+ TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
+ for r in event_results
+ )
break
def _execute(self, query, parameters, time_spent, max_wait):
timeout = (max_wait - time_spent) if max_wait is not None else None
- future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
+ future = self._session._create_response_future(
+ query, parameters, trace=False, custom_payload=None, timeout=timeout
+ )
# in case the user switched the row factory, set it to namedtuple for this query
future.row_factory = named_tuple_factory
future.send_request()
@@ -1079,12 +1288,22 @@ def _execute(self, query, parameters, time_spent, max_wait):
try:
return future.result()
except OperationTimedOut:
- raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
+ raise TraceUnavailable(
+ "Trace information was not available within %f seconds" % (max_wait,)
+ )
def __str__(self):
- return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \
- % (self.request_type, self.trace_id, self.coordinator, self.started_at,
- self.duration, self.parameters)
+ return (
+ "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s"
+ % (
+ self.request_type,
+ self.trace_id,
+ self.coordinator,
+ self.started_at,
+ self.duration,
+ self.parameters,
+ )
+ )
class TraceEvent(object):
@@ -1121,7 +1340,9 @@ class TraceEvent(object):
def __init__(self, description, timeuuid, source, source_elapsed, thread_name):
self.description = description
- self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc)
+ self.datetime = datetime.fromtimestamp(
+ unix_time_from_uuid1(timeuuid), tz=timezone.utc
+ )
self.source = source
if source_elapsed is not None:
self.source_elapsed = timedelta(microseconds=source_elapsed)
@@ -1130,7 +1351,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name):
self.thread_name = thread_name
def __str__(self):
- return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime)
+ return "%s on %s[%s] at %s" % (
+ self.description,
+ self.source,
+ self.thread_name,
+ self.datetime,
+ )
# TODO remove next major since we can target using the `host` attribute of session.execute
@@ -1139,9 +1365,12 @@ class HostTargetingStatement(object):
Wraps any query statement and attaches a target host, making
it usable in a targeted LBP without modifying the user's statement.
"""
+
def __init__(self, inner_statement, target_host):
- self.__class__ = type(inner_statement.__class__.__name__,
- (self.__class__, inner_statement.__class__),
- {})
- self.__dict__ = inner_statement.__dict__
- self.target_host = target_host
+ self.__class__ = type(
+ inner_statement.__class__.__name__,
+ (self.__class__, inner_statement.__class__),
+ {},
+ )
+ self.__dict__ = inner_statement.__dict__
+ self.target_host = target_host