diff --git a/cassandra/query.py b/cassandra/query.py
index 6c6878fdb4..af9ceff0eb 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -33,7 +33,15 @@
from cassandra.protocol import _UNSET_VALUE
from cassandra.util import OrderedDict, _sanitize_identifiers
+try:
+ from cassandra.serializers import make_serializers as _cython_make_serializers
+
+ _HAVE_CYTHON_SERIALIZERS = True
+except ImportError:
+ _HAVE_CYTHON_SERIALIZERS = False
+
import logging
+
log = logging.getLogger(__name__)
UNSET_VALUE = _UNSET_VALUE
@@ -49,9 +57,9 @@
Only valid when using native protocol v4+
"""
-NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
-START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
-END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$')
+NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]")
+START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*")
+END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$")
_clean_name_cache = {}
@@ -60,7 +68,9 @@ def _clean_column_name(name):
try:
return _clean_name_cache[name]
except KeyError:
- clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)))
+ clean = NON_ALPHA_REGEX.sub(
+ "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))
+ )
_clean_name_cache[name] = clean
return clean
@@ -83,6 +93,7 @@ def tuple_factory(colnames, rows):
"""
return rows
+
class PseudoNamedTupleRow(object):
"""
Helper class for pseudo_named_tuple_factory. These objects provide an
@@ -90,6 +101,7 @@ class PseudoNamedTupleRow(object):
but otherwise do not attempt to implement the full namedtuple or iterable
interface.
"""
+
def __init__(self, ordered_dict):
self._dict = ordered_dict
self._tuple = tuple(ordered_dict.values())
@@ -104,8 +116,7 @@ def __iter__(self):
return iter(self._tuple)
def __repr__(self):
- return '{t}({od})'.format(t=self.__class__.__name__,
- od=self._dict)
+ return "{t}({od})".format(t=self.__class__.__name__, od=self._dict)
def pseudo_namedtuple_factory(colnames, rows):
@@ -113,8 +124,7 @@ def pseudo_namedtuple_factory(colnames, rows):
Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback
factory for cases where :meth:`.named_tuple_factory` fails to create rows.
"""
- return [PseudoNamedTupleRow(od)
- for od in ordered_dict_factory(colnames, rows)]
+ return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)]
def named_tuple_factory(colnames, rows):
@@ -148,7 +158,7 @@ def named_tuple_factory(colnames, rows):
"""
clean_column_names = map(_clean_column_name, colnames)
try:
- Row = namedtuple('Row', clean_column_names)
+ Row = namedtuple("Row", clean_column_names)
except SyntaxError:
warnings.warn(
"Failed creating namedtuple for a result because there were too "
@@ -159,19 +169,23 @@ def named_tuple_factory(colnames, rows):
"values on row objects, Upgrade to Python 3.7, or use a different "
"row factory. (column names: {colnames})".format(
substitute_factory_name=pseudo_namedtuple_factory.__name__,
- colnames=colnames
+ colnames=colnames,
)
)
return pseudo_namedtuple_factory(colnames, rows)
except Exception:
- clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt
- log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) "
- "(see Python 'namedtuple' documentation for details on name rules). "
- "Results will be returned with positional names. "
-                    "Avoid this by choosing different names, using SELECT \"<col name>\" AS aliases, "
- "or specifying a different row_factory on your Session" %
- (colnames, clean_column_names))
- Row = namedtuple('Row', _sanitize_identifiers(clean_column_names))
+ clean_column_names = list(
+ map(_clean_column_name, colnames)
+ ) # create list because py3 map object will be consumed by first attempt
+ log.warning(
+ "Failed creating named tuple for results with column names %s (cleaned: %s) "
+ "(see Python 'namedtuple' documentation for details on name rules). "
+ "Results will be returned with positional names. "
+                'Avoid this by choosing different names, using SELECT "<col name>" AS aliases, '
+ "or specifying a different row_factory on your Session"
+ % (colnames, clean_column_names)
+ )
+ Row = namedtuple("Row", _sanitize_identifiers(clean_column_names))
return [Row(*row) for row in rows]
@@ -276,11 +290,24 @@ class Statement(object):
_serial_consistency_level = None
_routing_key = None
- def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
- serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
- is_idempotent=False, table=None):
- if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors
- raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
+ def __init__(
+ self,
+ retry_policy=None,
+ consistency_level=None,
+ routing_key=None,
+ serial_consistency_level=None,
+ fetch_size=FETCH_SIZE_UNSET,
+ keyspace=None,
+ custom_payload=None,
+ is_idempotent=False,
+ table=None,
+ ):
+ if retry_policy and not hasattr(
+ retry_policy, "on_read_timeout"
+ ): # just checking one method to detect positional parameter errors
+ raise ValueError(
+ "retry_policy should implement cassandra.policies.RetryPolicy"
+ )
if retry_policy is not None:
self.retry_policy = retry_policy
if consistency_level is not None:
@@ -329,17 +356,20 @@ def _del_routing_key(self):
If the partition key is a composite, a list or tuple must be passed in.
Each key component should be in its packed (binary) format, so all
components should be strings.
- """)
+ """,
+ )
def _get_serial_consistency_level(self):
return self._serial_consistency_level
def _set_serial_consistency_level(self, serial_consistency_level):
- if (serial_consistency_level is not None and
- not ConsistencyLevel.is_serial(serial_consistency_level)):
+ if serial_consistency_level is not None and not ConsistencyLevel.is_serial(
+ serial_consistency_level
+ ):
raise ValueError(
"serial_consistency_level must be either ConsistencyLevel.SERIAL "
- "or ConsistencyLevel.LOCAL_SERIAL")
+ "or ConsistencyLevel.LOCAL_SERIAL"
+ )
self._serial_consistency_level = serial_consistency_level
def _del_serial_consistency_level(self):
@@ -384,7 +414,8 @@ def is_lwt(self):
conditional statements.
.. versionadded:: 2.0.0
- """)
+ """,
+ )
class SimpleStatement(Statement):
@@ -392,9 +423,18 @@ class SimpleStatement(Statement):
A simple, un-prepared query.
"""
- def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None,
- serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
- custom_payload=None, is_idempotent=False):
+ def __init__(
+ self,
+ query_string,
+ retry_policy=None,
+ consistency_level=None,
+ routing_key=None,
+ serial_consistency_level=None,
+ fetch_size=FETCH_SIZE_UNSET,
+ keyspace=None,
+ custom_payload=None,
+ is_idempotent=False,
+ ):
"""
`query_string` should be a literal CQL statement with the exception
of parameter placeholders that will be filled through the
@@ -402,8 +442,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout
See :class:`Statement` attributes for a description of the other parameters.
"""
- Statement.__init__(self, retry_policy, consistency_level, routing_key,
- serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent)
+ Statement.__init__(
+ self,
+ retry_policy,
+ consistency_level,
+ routing_key,
+ serial_consistency_level,
+ fetch_size,
+ keyspace,
+ custom_payload,
+ is_idempotent,
+ )
self._query_string = query_string
@property
@@ -411,9 +460,14 @@ def query_string(self):
return self._query_string
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
-        return (u'<SimpleStatement query="%s", consistency=%s>' %
- (self.query_string, consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+        return '<SimpleStatement query="%s", consistency=%s>' % (
+ self.query_string,
+ consistency,
+ )
+
__repr__ = __str__
@@ -442,7 +496,7 @@ class PreparedStatement(object):
A note about * in prepared statements
"""
- column_metadata = None #TODO: make this bind_metadata in next major
+ column_metadata = None # TODO: make this bind_metadata in next major
retry_policy = None
consistency_level = None
custom_payload = None
@@ -459,9 +513,19 @@ class PreparedStatement(object):
serial_consistency_level = None # TODO never used?
_is_lwt = False
- def __init__(self, column_metadata, query_id, routing_key_indexes, query,
- keyspace, protocol_version, result_metadata, result_metadata_id,
- is_lwt=False, column_encryption_policy=None):
+ def __init__(
+ self,
+ column_metadata,
+ query_id,
+ routing_key_indexes,
+ query,
+ keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt=False,
+ column_encryption_policy=None,
+ ):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
@@ -474,14 +538,61 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
self.is_idempotent = False
self._is_lwt = is_lwt
+ @property
+ def _serializers(self):
+ """Lazily create and cache Cython serializers for column types.
+
+ Returns a list of Serializer objects if Cython serializers are available
+ and there is no column encryption policy, otherwise returns None.
+
+ The column_encryption_policy check is performed on every access (not
+ cached) so that serializers are correctly bypassed if a policy is set
+ after construction. This means the cache never goes stale: once a CE
+ policy is present, we always return None and fall through to the
+ encryption-aware bind path.
+ """
+ if self.column_encryption_policy:
+ return None
+ try:
+ return self._cached_serializers
+ except AttributeError:
+ pass
+ if _HAVE_CYTHON_SERIALIZERS and self.column_metadata:
+ self._cached_serializers = _cython_make_serializers(
+ [col.type for col in self.column_metadata]
+ )
+ else:
+ self._cached_serializers = None
+ return self._cached_serializers
+
@classmethod
- def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
- query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, is_lwt, column_encryption_policy=None):
+ def from_message(
+ cls,
+ query_id,
+ column_metadata,
+ pk_indexes,
+ cluster_metadata,
+ query,
+ prepared_keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt,
+ column_encryption_policy=None,
+ ):
if not column_metadata:
- return PreparedStatement(column_metadata, query_id, None,
- query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, is_lwt, column_encryption_policy)
+ return PreparedStatement(
+ column_metadata,
+ query_id,
+ None,
+ query,
+ prepared_keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt,
+ column_encryption_policy,
+ )
if pk_indexes:
routing_key_indexes = pk_indexes
@@ -496,18 +607,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
partition_key_columns = table_meta.partition_key
# make a map of {column_name: index} for each column in the statement
- statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
+ statement_indexes = dict(
+ (c.name, i) for i, c in enumerate(column_metadata)
+ )
# a list of which indexes in the statement correspond to partition key items
try:
- routing_key_indexes = [statement_indexes[c.name]
- for c in partition_key_columns]
- except KeyError: # we're missing a partition key component in the prepared
- pass # statement; just leave routing_key_indexes as None
-
- return PreparedStatement(column_metadata, query_id, routing_key_indexes,
- query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, is_lwt, column_encryption_policy)
+ routing_key_indexes = [
+ statement_indexes[c.name] for c in partition_key_columns
+ ]
+ except (
+ KeyError
+ ): # we're missing a partition key component in the prepared
+ pass # statement; just leave routing_key_indexes as None
+
+ return PreparedStatement(
+ column_metadata,
+ query_id,
+ routing_key_indexes,
+ query,
+ prepared_keyspace,
+ protocol_version,
+ result_metadata,
+ result_metadata_id,
+ is_lwt,
+ column_encryption_policy,
+ )
def bind(self, values):
"""
@@ -519,19 +644,54 @@ def bind(self, values):
def is_routing_key_index(self, i):
if self._routing_key_index_set is None:
- self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
+ self._routing_key_index_set = (
+ set(self.routing_key_indexes) if self.routing_key_indexes else set()
+ )
return i in self._routing_key_index_set
def is_lwt(self):
return self._is_lwt
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
-        return (u'<PreparedStatement query="%s", consistency=%s>' %
- (self.query_string, consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+        return '<PreparedStatement query="%s", consistency=%s>' % (
+ self.query_string,
+ consistency,
+ )
+
__repr__ = __str__
+def _raise_bind_serialize_error(col_spec, value, exc):
+ """Wrap TypeError, struct.error, or OverflowError with column context.
+
+ Called from all three bind loop paths (CE, Cython, plain Python) to
+ provide a uniform error message that includes the column name and
+ expected type. The message distinguishes between wrong-type and
+ out-of-range scenarios:
+
+ - TypeError -> "invalid type"
+ - OverflowError -> "value out of range"
+ - struct.error -> "value out of range"
+
+ Other exception types (e.g. ValueError from VectorType dimension
+ mismatch) propagate without wrapping.
+ """
+ actual_type = type(value)
+ if isinstance(exc, (OverflowError, struct.error)):
+ reason = "value out of range"
+ else:
+ reason = "invalid type"
+ message = (
+ 'Received an argument with %s for column "%s". '
+ "Expected: %s, Got: %s; (%s)"
+ % (reason, col_spec.name, col_spec.type, actual_type, exc)
+ )
+ raise TypeError(message) from exc
+
+
class BoundStatement(Statement):
"""
A prepared statement that has been bound to a particular set of values.
@@ -548,9 +708,17 @@ class BoundStatement(Statement):
The sequence of values that were bound to the prepared statement.
"""
- def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None,
- serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
- custom_payload=None):
+ def __init__(
+ self,
+ prepared_statement,
+ retry_policy=None,
+ consistency_level=None,
+ routing_key=None,
+ serial_consistency_level=None,
+ fetch_size=FETCH_SIZE_UNSET,
+ keyspace=None,
+ custom_payload=None,
+ ):
"""
`prepared_statement` should be an instance of :class:`PreparedStatement`.
@@ -571,9 +739,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None
self.keyspace = meta[0].keyspace_name
self.table = meta[0].table_name
- Statement.__init__(self, retry_policy, consistency_level, routing_key,
- serial_consistency_level, fetch_size, keyspace, custom_payload,
- prepared_statement.is_idempotent)
+ Statement.__init__(
+ self,
+ retry_policy,
+ consistency_level,
+ routing_key,
+ serial_consistency_level,
+ fetch_size,
+ keyspace,
+ custom_payload,
+ prepared_statement.is_idempotent,
+ )
def bind(self, values):
"""
@@ -615,64 +791,132 @@ def bind(self, values):
values.append(UNSET_VALUE)
else:
raise KeyError(
- 'Column name `%s` not found in bound dict.' %
- (col.name))
+ "Column name `%s` not found in bound dict." % (col.name)
+ )
value_len = len(values)
col_meta_len = len(col_meta)
if value_len > col_meta_len:
raise ValueError(
- "Too many arguments provided to bind() (got %d, expected %d)" %
- (len(values), len(col_meta)))
+ "Too many arguments provided to bind() (got %d, expected %d)"
+ % (len(values), len(col_meta))
+ )
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
# the error will be better reported when UNSET_VALUE is implicitly added.
- if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
- value_len < len(self.prepared_statement.routing_key_indexes):
+ if (
+ proto_version < 4
+ and self.prepared_statement.routing_key_indexes
+ and value_len < len(self.prepared_statement.routing_key_indexes)
+ ):
raise ValueError(
- "Too few arguments provided to bind() (got %d, required %d for routing key)" %
- (value_len, len(self.prepared_statement.routing_key_indexes)))
+ "Too few arguments provided to bind() (got %d, required %d for routing key)"
+ % (value_len, len(self.prepared_statement.routing_key_indexes))
+ )
self.raw_values = values
- self.values = []
- for value, col_spec in zip(values, col_meta):
- if value is None:
- self.values.append(None)
- elif value is UNSET_VALUE:
- if proto_version >= 4:
- self._append_unset_value()
+ # Pre-allocate to avoid repeated list growth reallocations
+ self.values = [None] * col_meta_len
+ idx = 0
+ if ce_policy:
+ # Column encryption path: check each column for CE policy
+ for value, col_spec in zip(values, col_meta):
+ if value is None:
+ self.values[idx] = None
+ elif value is UNSET_VALUE:
+ if proto_version >= 4:
+ idx = self._append_unset_value(idx)
+ continue
+ else:
+ raise ValueError(
+ "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)"
+ % proto_version
+ )
else:
- raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
+ try:
+ col_desc = ColDesc(
+ col_spec.keyspace_name, col_spec.table_name, col_spec.name
+ )
+ uses_ce = ce_policy.contains_column(col_desc)
+ if uses_ce:
+ col_type = ce_policy.column_type(col_desc)
+ col_bytes = col_type.serialize(value, proto_version)
+ col_bytes = ce_policy.encrypt(col_desc, col_bytes)
+ else:
+ col_bytes = col_spec.type.serialize(value, proto_version)
+ self.values[idx] = col_bytes
+ # struct.error: int32 out-of-range; OverflowError: float out-of-range
+ except (TypeError, struct.error, OverflowError) as exc:
+ _raise_bind_serialize_error(col_spec, value, exc)
+ idx += 1
+ else:
+ # Fast path: no column encryption, use Cython serializers if available
+ serializers = self.prepared_statement._serializers
+ if serializers is not None:
+ for ser, value, col_spec in zip(serializers, values, col_meta):
+ if value is None:
+ self.values[idx] = None
+ elif value is UNSET_VALUE:
+ if proto_version >= 4:
+ idx = self._append_unset_value(idx)
+ continue
+ else:
+ raise ValueError(
+ "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)"
+ % proto_version
+ )
+ else:
+ try:
+ col_bytes = ser.serialize(value, proto_version)
+ self.values[idx] = col_bytes
+ # struct.error: int32 out-of-range; OverflowError: float out-of-range
+ except (TypeError, struct.error, OverflowError) as exc:
+ _raise_bind_serialize_error(col_spec, value, exc)
+ idx += 1
else:
- try:
- col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name)
- uses_ce = ce_policy and ce_policy.contains_column(col_desc)
- col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type
- col_bytes = col_type.serialize(value, proto_version)
- if uses_ce:
- col_bytes = ce_policy.encrypt(col_desc, col_bytes)
- self.values.append(col_bytes)
- except (TypeError, struct.error) as exc:
- actual_type = type(value)
- message = ('Received an argument of invalid type for column "%s". '
- 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
- raise TypeError(message)
+ for value, col_spec in zip(values, col_meta):
+ if value is None:
+ self.values[idx] = None
+ elif value is UNSET_VALUE:
+ if proto_version >= 4:
+ idx = self._append_unset_value(idx)
+ continue
+ else:
+ raise ValueError(
+ "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)"
+ % proto_version
+ )
+ else:
+ try:
+ col_bytes = col_spec.type.serialize(value, proto_version)
+ self.values[idx] = col_bytes
+ # struct.error: int32 out-of-range; OverflowError: float out-of-range
+ except (TypeError, struct.error, OverflowError) as exc:
+ _raise_bind_serialize_error(col_spec, value, exc)
+ idx += 1
if proto_version >= 4:
- diff = col_meta_len - len(self.values)
- if diff:
- for _ in range(diff):
- self._append_unset_value()
+ # Fill remaining unbound columns with UNSET_VALUE (v4+ feature).
+ # The pre-allocated list already has slots for these, so index
+ # assignment works directly without trimming first.
+ while idx < col_meta_len:
+ idx = self._append_unset_value(idx)
+ elif idx < col_meta_len:
+ # Pre-v4: trim trailing unused slots (no UNSET_VALUE support)
+ self.values = self.values[:idx]
return self
- def _append_unset_value(self):
- next_index = len(self.values)
- if self.prepared_statement.is_routing_key_index(next_index):
- col_meta = self.prepared_statement.column_metadata[next_index]
- raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
- self.values.append(UNSET_VALUE)
+ def _append_unset_value(self, idx):
+ if self.prepared_statement.is_routing_key_index(idx):
+ col_meta = self.prepared_statement.column_metadata[idx]
+ raise ValueError(
+ "Cannot bind UNSET_VALUE as a part of the routing key '%s'"
+ % col_meta.name
+ )
+ self.values[idx] = UNSET_VALUE
+ return idx + 1
@property
def routing_key(self):
@@ -686,7 +930,9 @@ def routing_key(self):
if len(routing_indexes) == 1:
self._routing_key = self.values[routing_indexes[0]]
else:
- self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes))
+ self._routing_key = b"".join(
+ self._key_parts_packed(self.values[i] for i in routing_indexes)
+ )
return self._routing_key
@@ -694,9 +940,15 @@ def is_lwt(self):
return self.prepared_statement.is_lwt()
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
-        return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %
- (self.prepared_statement.query_string, self.raw_values, consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+        return '<BoundStatement query="%s", values=%s, consistency=%s>' % (
+ self.prepared_statement.query_string,
+ self.raw_values,
+ consistency,
+ )
+
__repr__ = __str__
@@ -731,7 +983,7 @@ def __str__(self):
return self.name
def __repr__(self):
- return "BatchType.%s" % (self.name, )
+ return "BatchType.%s" % (self.name,)
BatchType.LOGGED = BatchType("LOGGED", 0)
@@ -763,9 +1015,15 @@ class BatchStatement(Statement):
_session = None
_is_lwt = False
- def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
- consistency_level=None, serial_consistency_level=None,
- session=None, custom_payload=None):
+ def __init__(
+ self,
+ batch_type=BatchType.LOGGED,
+ retry_policy=None,
+ consistency_level=None,
+ serial_consistency_level=None,
+ session=None,
+ custom_payload=None,
+ ):
"""
`batch_type` specifies The :class:`.BatchType` for the batch operation.
Defaults to :attr:`.BatchType.LOGGED`.
@@ -813,8 +1071,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
self.batch_type = batch_type
self._statements_and_parameters = []
self._session = session
- Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
- serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
+ Statement.__init__(
+ self,
+ retry_policy=retry_policy,
+ consistency_level=consistency_level,
+ serial_consistency_level=serial_consistency_level,
+ custom_payload=custom_payload,
+ )
def clear(self):
"""
@@ -853,11 +1116,14 @@ def add(self, statement, parameters=None):
if parameters:
raise ValueError(
"Parameters cannot be passed with a BoundStatement "
- "to BatchStatement.add()")
+ "to BatchStatement.add()"
+ )
self._update_state(statement)
if statement.is_lwt():
self._is_lwt = True
- self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values)
+ self._add_statement_and_params(
+ True, statement.prepared_statement.query_id, statement.values
+ )
else:
# it must be a SimpleStatement
query_string = statement.query_string
@@ -881,7 +1147,9 @@ def add_all(self, statements, parameters):
def _add_statement_and_params(self, is_prepared, statement, parameters):
if len(self._statements_and_parameters) >= 0xFFFF:
- raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF)
+ raise ValueError(
+ "Batch statement cannot contain more than %d statements." % 0xFFFF
+ )
self._statements_and_parameters.append((is_prepared, statement, parameters))
def _maybe_set_routing_attributes(self, statement):
@@ -907,9 +1175,15 @@ def __len__(self):
return len(self._statements_and_parameters)
def __str__(self):
- consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
-        return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
- (self.batch_type, len(self), consistency))
+ consistency = ConsistencyLevel.value_to_name.get(
+ self.consistency_level, "Not Set"
+ )
+        return "<BatchStatement type=%s, statements=%d, consistency=%s>" % (
+ self.batch_type,
+ len(self),
+ consistency,
+ )
+
__repr__ = __str__
@@ -931,7 +1205,9 @@ def __str__(self):
def bind_params(query, params, encoder):
if isinstance(params, dict):
- return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items())
+ return query % dict(
+ (k, encoder.cql_encode_all_types(v)) for k, v in params.items()
+ )
else:
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
@@ -940,6 +1216,7 @@ class TraceUnavailable(Exception):
"""
Raised when complete trace details cannot be fetched from Cassandra.
"""
+
pass
@@ -1000,7 +1277,9 @@ class QueryTrace(object):
_session = None
- _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s"
+ _SELECT_SESSIONS_FORMAT = (
+ "SELECT * FROM system_traces.sessions WHERE session_id = %s"
+ )
_SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s"
_BASE_RETRY_SLEEP = 0.003
@@ -1029,18 +1308,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
time_spent = time.time() - start
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable(
- "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,))
+ "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait."
+ % (max_wait,)
+ )
log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
- metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout
+ metadata_request_timeout = (
+ self._session.cluster.control_connection
+ and self._session.cluster.control_connection._metadata_request_timeout
+ )
session_results = self._execute(
- SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait)
+ SimpleStatement(
+ maybe_add_timeout_to_query(
+ self._SELECT_SESSIONS_FORMAT, metadata_request_timeout
+ ),
+ consistency_level=query_cl,
+ ),
+ (self.trace_id,),
+ time_spent,
+ max_wait,
+ )
# PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries
session_row = session_results.one() if session_results else None
- is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None
+ is_complete = (
+ session_row is not None
+ and session_row.duration is not None
+ and session_row.started_at is not None
+ )
if not session_results or (wait_for_complete and not is_complete):
- time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
+ time.sleep(self._BASE_RETRY_SLEEP * (2**attempt))
attempt += 1
continue
if is_complete:
@@ -1049,29 +1346,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
log.debug("Fetching parital trace info for trace ID: %s", self.trace_id)
self.request_type = session_row.request
- self.duration = timedelta(microseconds=session_row.duration) if is_complete else None
+ self.duration = (
+ timedelta(microseconds=session_row.duration) if is_complete else None
+ )
self.started_at = session_row.started_at
self.coordinator = session_row.coordinator
self.parameters = session_row.parameters
# since C* 2.2
- self.client = getattr(session_row, 'client', None)
+ self.client = getattr(session_row, "client", None)
- log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id)
+ log.debug(
+ "Attempting to fetch trace events for trace ID: %s", self.trace_id
+ )
time_spent = time.time() - start
event_results = self._execute(
- SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout),
- consistency_level=query_cl),
+ SimpleStatement(
+ maybe_add_timeout_to_query(
+ self._SELECT_EVENTS_FORMAT, metadata_request_timeout
+ ),
+ consistency_level=query_cl,
+ ),
(self.trace_id,),
time_spent,
- max_wait)
+ max_wait,
+ )
log.debug("Fetched trace events for trace ID: %s", self.trace_id)
- self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
- for r in event_results)
+ self.events = tuple(
+ TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
+ for r in event_results
+ )
break
def _execute(self, query, parameters, time_spent, max_wait):
timeout = (max_wait - time_spent) if max_wait is not None else None
- future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
+ future = self._session._create_response_future(
+ query, parameters, trace=False, custom_payload=None, timeout=timeout
+ )
# in case the user switched the row factory, set it to namedtuple for this query
future.row_factory = named_tuple_factory
future.send_request()
@@ -1079,12 +1389,22 @@ def _execute(self, query, parameters, time_spent, max_wait):
try:
return future.result()
except OperationTimedOut:
- raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
+ raise TraceUnavailable(
+ "Trace information was not available within %f seconds" % (max_wait,)
+ )
def __str__(self):
- return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \
- % (self.request_type, self.trace_id, self.coordinator, self.started_at,
- self.duration, self.parameters)
+ return (
+ "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s"
+ % (
+ self.request_type,
+ self.trace_id,
+ self.coordinator,
+ self.started_at,
+ self.duration,
+ self.parameters,
+ )
+ )
class TraceEvent(object):
@@ -1121,7 +1441,9 @@ class TraceEvent(object):
def __init__(self, description, timeuuid, source, source_elapsed, thread_name):
self.description = description
- self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc)
+ self.datetime = datetime.fromtimestamp(
+ unix_time_from_uuid1(timeuuid), tz=timezone.utc
+ )
self.source = source
if source_elapsed is not None:
self.source_elapsed = timedelta(microseconds=source_elapsed)
@@ -1130,7 +1452,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name):
self.thread_name = thread_name
def __str__(self):
- return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime)
+ return "%s on %s[%s] at %s" % (
+ self.description,
+ self.source,
+ self.thread_name,
+ self.datetime,
+ )
# TODO remove next major since we can target using the `host` attribute of session.execute
@@ -1139,9 +1466,12 @@ class HostTargetingStatement(object):
Wraps any query statement and attaches a target host, making
it usable in a targeted LBP without modifying the user's statement.
"""
+
def __init__(self, inner_statement, target_host):
- self.__class__ = type(inner_statement.__class__.__name__,
- (self.__class__, inner_statement.__class__),
- {})
- self.__dict__ = inner_statement.__dict__
- self.target_host = target_host
+ self.__class__ = type(
+ inner_statement.__class__.__name__,
+ (self.__class__, inner_statement.__class__),
+ {},
+ )
+ self.__dict__ = inner_statement.__dict__
+ self.target_host = target_host
diff --git a/cassandra/serializers.pxd b/cassandra/serializers.pxd
new file mode 100644
index 0000000000..60297077a8
--- /dev/null
+++ b/cassandra/serializers.pxd
@@ -0,0 +1,20 @@
+# Copyright ScyllaDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+cdef class Serializer:
+ # The cqltypes._CassandraType corresponding to this serializer
+ cdef object cqltype
+
+ cpdef bytes serialize(self, object value, int protocol_version)
diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx
new file mode 100644
index 0000000000..b02d0861cc
--- /dev/null
+++ b/cassandra/serializers.pyx
@@ -0,0 +1,398 @@
+# Copyright ScyllaDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Cython-optimized serializers for CQL types.
+
+Mirrors the architecture of deserializers.pyx. Currently implements
+optimized serialization for:
+- FloatType (4-byte big-endian float)
+- DoubleType (8-byte big-endian double)
+- Int32Type (4-byte big-endian signed int)
+- VectorType (type-specialized for float/double/int32, generic fallback)
+
+For all other types, GenericSerializer delegates to the Python-level
+cqltype.serialize() classmethod.
+"""
+
+from libc.stdint cimport int32_t
+from libc.string cimport memcpy
+from libc.float cimport FLT_MAX
+from libc.math cimport isinf, isnan
+from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING
+
+from cassandra import cqltypes
+import io
+from cassandra.marshal import uvint_pack
+
+cdef bint is_little_endian
+from cassandra.util import is_little_endian
+
+
+# ---------------------------------------------------------------------------
+# Base class
+# ---------------------------------------------------------------------------
+
+cdef class Serializer:
+ """Cython-based serializer class for a cqltype"""
+
+ def __init__(self, cqltype):
+ self.cqltype = cqltype
+
+ cpdef bytes serialize(self, object value, int protocol_version):
+ raise NotImplementedError
+
+
+# ---------------------------------------------------------------------------
+# Float range check
+# ---------------------------------------------------------------------------
+
+cdef inline void _check_float_range(double value) except *:
+ """Raise OverflowError for finite values outside float32 range.
+
+ Matches the behaviour of struct.pack('>f', value), which raises
+ OverflowError for values that cannot be represented as a 32-bit
+ IEEE 754 float. inf, -inf, and nan pass through unchanged.
+ """
+ if not isinf(value) and not isnan(value):
+ if value > FLT_MAX or value < -FLT_MAX:
+ raise OverflowError(
+ "Value %r out of range for float32 (max magnitude %r)" % (value, FLT_MAX)
+ )
+
+
+# ---------------------------------------------------------------------------
+# Int32 range check
+# ---------------------------------------------------------------------------
+
+cdef inline void _check_int32_range(object value) except *:
+ """Raise OverflowError for values outside the signed int32 range.
+
+ Mirrors ``_check_float_range``: we intentionally raise OverflowError
+ (not struct.error) so callers only need to catch one exception type
+ for out-of-range values. The check must be done on the Python int
+ *before* the C-level cast, which would silently truncate.
+ """
+ if value > 2147483647 or value < -2147483648:
+ raise OverflowError(
+ "'i' format requires -2147483648 <= number <= 2147483647"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Scalar serializers
+# ---------------------------------------------------------------------------
+
+cdef class SerFloatType(Serializer):
+ """Serialize a Python float to 4-byte big-endian IEEE 754."""
+
+ cpdef bytes serialize(self, object value, int protocol_version):
+ _check_float_range(value)
+ cdef float val = value
+ cdef char out[4]
+ cdef char *src = <char *> &val
+
+ if is_little_endian:
+ out[0] = src[3]
+ out[1] = src[2]
+ out[2] = src[1]
+ out[3] = src[0]
+ else:
+ memcpy(out, src, 4)
+
+ return PyBytes_FromStringAndSize(out, 4)
+
+
+cdef class SerDoubleType(Serializer):
+ """Serialize a Python float to 8-byte big-endian IEEE 754."""
+
+ cpdef bytes serialize(self, object value, int protocol_version):
+ cdef double val = value
+ cdef char out[8]
+ cdef char *src = <char *> &val
+
+ if is_little_endian:
+ out[0] = src[7]
+ out[1] = src[6]
+ out[2] = src[5]
+ out[3] = src[4]
+ out[4] = src[3]
+ out[5] = src[2]
+ out[6] = src[1]
+ out[7] = src[0]
+ else:
+ memcpy(out, src, 8)
+
+ return PyBytes_FromStringAndSize(out, 8)
+
+
+cdef class SerInt32Type(Serializer):
+ """Serialize a Python int to 4-byte big-endian signed int32."""
+
+ cpdef bytes serialize(self, object value, int protocol_version):
+ _check_int32_range(value)
+ cdef int32_t val = value
+ cdef char out[4]
+ cdef char *src = <char *> &val
+
+ if is_little_endian:
+ out[0] = src[3]
+ out[1] = src[2]
+ out[2] = src[1]
+ out[3] = src[0]
+ else:
+ memcpy(out, src, 4)
+
+ return PyBytes_FromStringAndSize(out, 4)
+
+
+# ---------------------------------------------------------------------------
+# Type detection helpers
+# ---------------------------------------------------------------------------
+
+cdef inline bint _is_float_type(object subtype):
+ return subtype is cqltypes.FloatType or issubclass(subtype, cqltypes.FloatType)
+
+cdef inline bint _is_double_type(object subtype):
+ return subtype is cqltypes.DoubleType or issubclass(subtype, cqltypes.DoubleType)
+
+cdef inline bint _is_int32_type(object subtype):
+ return subtype is cqltypes.Int32Type or issubclass(subtype, cqltypes.Int32Type)
+
+
+# ---------------------------------------------------------------------------
+# VectorType serializer
+# ---------------------------------------------------------------------------
+
+cdef class SerVectorType(Serializer):
+ """
+ Optimized Cython serializer for VectorType.
+
+ For float, double, and int32 vectors, pre-allocates a contiguous buffer
+ and uses C-level byte swapping. For other subtypes, falls back to
+ per-element Python serialization.
+ """
+
+ cdef int vector_size
+ cdef object subtype
+ # 0 = generic, 1 = float, 2 = double, 3 = int32
+ cdef int type_code
+
+ def __init__(self, cqltype):
+ super().__init__(cqltype)
+ self.vector_size = cqltype.vector_size
+ self.subtype = cqltype.subtype
+
+ if _is_float_type(self.subtype):
+ self.type_code = 1
+ elif _is_double_type(self.subtype):
+ self.type_code = 2
+ elif _is_int32_type(self.subtype):
+ self.type_code = 3
+ else:
+ self.type_code = 0
+
+ cpdef bytes serialize(self, object value, int protocol_version):
+ cdef Py_ssize_t v_length = len(value)
+ if v_length != self.vector_size:
+ raise ValueError(
+ "Expected sequence of size %d for vector of type %s and "
+ "dimension %d, observed sequence of length %d" % (
+ self.vector_size, self.subtype.typename,
+ self.vector_size, v_length))
+
+ if self.type_code == 1:
+ return self._serialize_float(value)
+ elif self.type_code == 2:
+ return self._serialize_double(value)
+ elif self.type_code == 3:
+ return self._serialize_int32(value)
+ else:
+ return self._serialize_generic(value, protocol_version)
+
+ cdef inline bytes _serialize_float(self, object values):
+ """Serialize a sequence of floats into a contiguous big-endian buffer.
+
+ Uses index-based access (values[i]) rather than iteration for
+ performance — the input must support ``__getitem__`` (list, tuple,
+ etc.). This is intentional: index access lets Cython emit a single
+ ``PyObject_GetItem`` call per element instead of iterator protocol
+ overhead.
+ """
+ cdef Py_ssize_t i
+ cdef Py_ssize_t buf_size = self.vector_size * 4
+ if buf_size == 0:
+ return b""
+
+ cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
+ cdef char *buf = PyBytes_AS_STRING(result)
+
+ cdef float val
+ cdef char *src
+ cdef char *dst
+
+ for i in range(self.vector_size):
+ _check_float_range(values[i])
+ val = values[i]
+ src = <char *> &val
+ dst = buf + i * 4
+
+ if is_little_endian:
+ dst[0] = src[3]
+ dst[1] = src[2]
+ dst[2] = src[1]
+ dst[3] = src[0]
+ else:
+ memcpy(dst, src, 4)
+
+ return result
+
+ cdef inline bytes _serialize_double(self, object values):
+ """Serialize a sequence of doubles into a contiguous big-endian buffer.
+
+ Uses index-based access (values[i]) rather than iteration for
+ performance — the input must support ``__getitem__`` (list, tuple,
+ etc.). This is intentional: index access lets Cython emit a single
+ ``PyObject_GetItem`` call per element instead of iterator protocol
+ overhead.
+ """
+ cdef Py_ssize_t i
+ cdef Py_ssize_t buf_size = self.vector_size * 8
+ if buf_size == 0:
+ return b""
+
+ cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
+ cdef char *buf = PyBytes_AS_STRING(result)
+
+ cdef double val
+ cdef char *src
+ cdef char *dst
+
+ for i in range(self.vector_size):
+ val = values[i]
+ src = <char *> &val
+ dst = buf + i * 8
+
+ if is_little_endian:
+ dst[0] = src[7]
+ dst[1] = src[6]
+ dst[2] = src[5]
+ dst[3] = src[4]
+ dst[4] = src[3]
+ dst[5] = src[2]
+ dst[6] = src[1]
+ dst[7] = src[0]
+ else:
+ memcpy(dst, src, 8)
+
+ return result
+
+ cdef inline bytes _serialize_int32(self, object values):
+ """Serialize a sequence of int32 values into a contiguous big-endian buffer.
+
+ Uses index-based access (values[i]) rather than iteration for
+ performance — the input must support ``__getitem__`` (list, tuple,
+ etc.). This is intentional: index access lets Cython emit a single
+ ``PyObject_GetItem`` call per element instead of iterator protocol
+ overhead.
+ """
+ cdef Py_ssize_t i
+ cdef Py_ssize_t buf_size = self.vector_size * 4
+ if buf_size == 0:
+ return b""
+
+ cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
+ cdef char *buf = PyBytes_AS_STRING(result)
+
+ cdef int32_t val
+ cdef char *src
+ cdef char *dst
+
+ for i in range(self.vector_size):
+ _check_int32_range(values[i])
+ val = values[i]
+ src = <char *> &val
+ dst = buf + i * 4
+
+ if is_little_endian:
+ dst[0] = src[3]
+ dst[1] = src[2]
+ dst[2] = src[1]
+ dst[3] = src[0]
+ else:
+ memcpy(dst, src, 4)
+
+ return result
+
+ cdef inline bytes _serialize_generic(self, object values, int protocol_version):
+ """Fallback: element-by-element Python serialization for non-optimized types."""
+ serialized_size = self.subtype.serial_size()
+ buf = io.BytesIO()
+ for item in values:
+ item_bytes = self.subtype.serialize(item, protocol_version)
+ if serialized_size is None:
+ buf.write(uvint_pack(len(item_bytes)))
+ buf.write(item_bytes)
+ return buf.getvalue()
+
+
+# ---------------------------------------------------------------------------
+# Generic serializer (fallback for all other types)
+# ---------------------------------------------------------------------------
+
+cdef class GenericSerializer(Serializer):
+ """
+ Wraps a generic cqltype for serialization, delegating to the Python-level
+ cqltype.serialize() classmethod.
+ """
+
+ cpdef bytes serialize(self, object value, int protocol_version):
+ return self.cqltype.serialize(value, protocol_version)
+
+ def __repr__(self):
+ return "GenericSerializer(%s)" % (self.cqltype,)
+
+
+# ---------------------------------------------------------------------------
+# Lookup and factory
+# ---------------------------------------------------------------------------
+
+cdef dict _ser_classes = {}
+
+cpdef Serializer find_serializer(cqltype):
+ """Find a serializer for a cqltype."""
+
+ # For VectorType, always use SerVectorType (it handles generic subtypes internally)
+ if issubclass(cqltype, cqltypes.VectorType):
+ return SerVectorType(cqltype)
+
+ # For scalar types with dedicated serializers, look up by name
+ name = 'Ser' + cqltype.__name__
+ cls = _ser_classes.get(name)
+ if cls is not None:
+ return cls(cqltype)
+
+ # Fallback to generic
+ return GenericSerializer(cqltype)
+
+
+def make_serializers(cqltypes_list):
+ """Create a list of Serializer objects for each given cqltype."""
+ return [find_serializer(ct) for ct in cqltypes_list]
+
+
+# Build the lookup dict for scalar serializers at module load time
+_ser_classes['SerFloatType'] = SerFloatType
+_ser_classes['SerDoubleType'] = SerDoubleType
+_ser_classes['SerInt32Type'] = SerInt32Type
diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py
index 5416ac461d..59c238169b 100644
--- a/tests/unit/test_parameter_binding.py
+++ b/tests/unit/test_parameter_binding.py
@@ -1,4 +1,4 @@
-# Copyright DataStax, Inc.
+# Copyright ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,18 @@
# limitations under the License.
import unittest
+import struct
import pytest
from cassandra.encoder import Encoder
from cassandra.protocol import ColumnMetadata
-from cassandra.query import (bind_params, ValueSequence, PreparedStatement,
- BoundStatement, UNSET_VALUE)
+from cassandra.query import (
+ bind_params,
+ ValueSequence,
+ PreparedStatement,
+ BoundStatement,
+ UNSET_VALUE,
+)
from cassandra.cqltypes import Int32Type
from cassandra.util import OrderedDict
@@ -26,7 +32,6 @@
class ParamBindingTest(unittest.TestCase):
-
def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0), Encoder())
assert result == "1 'a' 2.0"
@@ -48,18 +53,18 @@ def test_none_param(self):
assert result == "NULL"
def test_list_collection(self):
- result = bind_params("%s", (['a', 'b', 'c'],), Encoder())
+ result = bind_params("%s", (["a", "b", "c"],), Encoder())
assert result == "['a', 'b', 'c']"
def test_set_collection(self):
- result = bind_params("%s", (set(['a', 'b']),), Encoder())
+ result = bind_params("%s", (set(["a", "b"]),), Encoder())
assert result in ("{'a', 'b'}", "{'b', 'a'}")
def test_map_collection(self):
vals = OrderedDict()
- vals['a'] = 'a'
- vals['b'] = 'b'
- vals['c'] = 'c'
+ vals["a"] = "a"
+ vals["b"] = "b"
+ vals["c"] = "c"
result = bind_params("%s", (vals,), Encoder())
assert result == "{'a': 'a', 'b': 'b', 'c': 'c'}"
@@ -68,63 +73,70 @@ def test_quote_escaping(self):
assert result == """'''ef''''ef"ef""ef'''"""
def test_float_precision(self):
- f = 3.4028234663852886e+38
+ f = 3.4028234663852886e38
assert float(bind_params("%s", (f,), Encoder())) == f
-class BoundStatementTestV3(unittest.TestCase):
+class BoundStatementTestV3(unittest.TestCase):
protocol_version = 3
@classmethod
def setUpClass(cls):
- column_metadata = [ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type),
- ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type),
- ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type),
- ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)]
- cls.prepared = PreparedStatement(column_metadata=column_metadata,
- query_id=None,
- routing_key_indexes=[1, 0],
- query=None,
- keyspace='keyspace',
- protocol_version=cls.protocol_version, result_metadata=None,
- result_metadata_id=None)
+ column_metadata = [
+ ColumnMetadata("keyspace", "cf", "rk0", Int32Type),
+ ColumnMetadata("keyspace", "cf", "rk1", Int32Type),
+ ColumnMetadata("keyspace", "cf", "ck0", Int32Type),
+ ColumnMetadata("keyspace", "cf", "v0", Int32Type),
+ ]
+ cls.prepared = PreparedStatement(
+ column_metadata=column_metadata,
+ query_id=None,
+ routing_key_indexes=[1, 0],
+ query=None,
+ keyspace="keyspace",
+ protocol_version=cls.protocol_version,
+ result_metadata=None,
+ result_metadata_id=None,
+ )
cls.bound = BoundStatement(prepared_statement=cls.prepared)
def test_invalid_argument_type(self):
- values = (0, 0, 0, 'string not int')
+ values = (0, 0, 0, "string not int")
with pytest.raises(TypeError) as exc:
self.bound.bind(values)
e = exc.value
- assert 'v0' in str(e)
- assert 'Int32Type' in str(e)
- assert 'str' in str(e)
+ assert "v0" in str(e)
+ assert "Int32Type" in str(e)
+ assert "str" in str(e)
- values = (['1', '2'], 0, 0, 0)
+ values = (["1", "2"], 0, 0, 0)
with pytest.raises(TypeError) as exc:
self.bound.bind(values)
e = exc.value
- assert 'rk0' in str(e)
- assert 'Int32Type' in str(e)
- assert 'list' in str(e)
+ assert "rk0" in str(e)
+ assert "Int32Type" in str(e)
+ assert "list" in str(e)
def test_inherit_fetch_size(self):
- keyspace = 'keyspace1'
- column_family = 'cf1'
+ keyspace = "keyspace1"
+ column_family = "cf1"
column_metadata = [
- ColumnMetadata(keyspace, column_family, 'foo1', Int32Type),
- ColumnMetadata(keyspace, column_family, 'foo2', Int32Type)
+ ColumnMetadata(keyspace, column_family, "foo1", Int32Type),
+ ColumnMetadata(keyspace, column_family, "foo2", Int32Type),
]
- prepared_statement = PreparedStatement(column_metadata=column_metadata,
- query_id=None,
- routing_key_indexes=[],
- query=None,
- keyspace=keyspace,
- protocol_version=self.protocol_version,
- result_metadata=None,
- result_metadata_id=None)
+ prepared_statement = PreparedStatement(
+ column_metadata=column_metadata,
+ query_id=None,
+ routing_key_indexes=[],
+ query=None,
+ keyspace=keyspace,
+ protocol_version=self.protocol_version,
+ result_metadata=None,
+ result_metadata_id=None,
+ )
prepared_statement.fetch_size = 1234
bound_statement = BoundStatement(prepared_statement=prepared_statement)
assert 1234 == bound_statement.fetch_size
@@ -134,21 +146,23 @@ def test_too_few_parameters_for_routing_key(self):
self.prepared.bind((1,))
bound = self.prepared.bind((1, 2))
- assert bound.keyspace == 'keyspace'
+ assert bound.keyspace == "keyspace"
def test_dict_missing_routing_key(self):
with pytest.raises(KeyError):
- self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0})
+ self.bound.bind({"rk0": 0, "ck0": 0, "v0": 0})
with pytest.raises(KeyError):
- self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0})
+ self.bound.bind({"rk1": 0, "ck0": 0, "v0": 0})
def test_missing_value(self):
with pytest.raises(KeyError):
- self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
+ self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0})
def test_extra_value(self):
- self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict
- assert self.bound.values == [b'\x00' * 4] * 4 # four encoded zeros
+ self.bound.bind(
+ {"rk0": 0, "rk1": 0, "ck0": 0, "v0": 0, "should_not_be_here": 123}
+ ) # okay to have extra keys in dict
+ assert self.bound.values == [b"\x00" * 4] * 4 # four encoded zeros
with pytest.raises(ValueError):
self.bound.bind((0, 0, 0, 0, 123))
@@ -158,19 +172,21 @@ def test_values_none(self):
self.bound.bind(None)
# prepared statement with no values
- prepared_statement = PreparedStatement(column_metadata=[],
- query_id=None,
- routing_key_indexes=[],
- query=None,
- keyspace='whatever',
- protocol_version=self.protocol_version,
- result_metadata=None,
- result_metadata_id=None)
+ prepared_statement = PreparedStatement(
+ column_metadata=[],
+ query_id=None,
+ routing_key_indexes=[],
+ query=None,
+ keyspace="whatever",
+ protocol_version=self.protocol_version,
+ result_metadata=None,
+ result_metadata_id=None,
+ )
bound = prepared_statement.bind(None)
assertListEqual(bound.values, [])
def test_bind_none(self):
- self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None})
+ self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": None})
assert self.bound.values[-1] == None
old_values = self.bound.values
@@ -180,7 +196,7 @@ def test_bind_none(self):
def test_unset_value(self):
with pytest.raises(ValueError):
- self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
+ self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": UNSET_VALUE})
with pytest.raises(ValueError):
self.bound.bind((0, 0, 0, UNSET_VALUE))
@@ -192,13 +208,13 @@ def test_dict_missing_routing_key(self):
# in v4 it implicitly binds UNSET_VALUE for missing items,
# UNSET_VALUE is ValueError for routing keys
with pytest.raises(ValueError):
- self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0})
+ self.bound.bind({"rk0": 0, "ck0": 0, "v0": 0})
with pytest.raises(ValueError):
- self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0})
+ self.bound.bind({"rk1": 0, "ck0": 0, "v0": 0})
def test_missing_value(self):
# in v4 missing values are UNSET_VALUE
- self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0})
+ self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0})
assert self.bound.values[-1] == UNSET_VALUE
old_values = self.bound.values
@@ -207,7 +223,7 @@ def test_missing_value(self):
assert self.bound.values[-1] == UNSET_VALUE
def test_unset_value(self):
- self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE})
+ self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": UNSET_VALUE})
assert self.bound.values[-1] == UNSET_VALUE
self.bound.bind((0, 0, 0, UNSET_VALUE))
@@ -216,3 +232,286 @@ def test_unset_value(self):
class BoundStatementTestV5(BoundStatementTestV4):
protocol_version = 5
+
+
+class StubSerializer:
+ """Stub that mimics a Cython Serializer object for testing the fast path."""
+
+ def __init__(self, cqltype):
+ self.cqltype = cqltype
+
+ def serialize(self, value, protocol_version):
+ return self.cqltype.serialize(value, protocol_version)
+
+
+class OverflowSerializer:
+ """Stub that raises struct.error, mimicking Cython int32 range check."""
+
+ def __init__(self, cqltype):
+ self.cqltype = cqltype
+
+ def serialize(self, value, protocol_version):
+ raise struct.error("'i' format requires -2147483648 <= number <= 2147483647")
+
+
+class CythonBindPathTest(unittest.TestCase):
+ """Tests for the Cython serializer fast path in BoundStatement.bind().
+
+ These tests inject stub serializers via the PreparedStatement's cached
+ _cached_serializers attribute to exercise the Cython bind branch without
+ requiring compiled Cython.
+ """
+
+ protocol_version = 4
+
+ def _make_prepared(self, column_metadata, serializers=None):
+ """Create a PreparedStatement and inject serializers into its cache."""
+ prepared = PreparedStatement(
+ column_metadata=column_metadata,
+ query_id=None,
+ routing_key_indexes=[],
+ query=None,
+ keyspace="keyspace",
+ protocol_version=self.protocol_version,
+ result_metadata=None,
+ result_metadata_id=None,
+ )
+ # Inject directly into the cache attribute used by the _serializers
+ # property, bypassing the lazy initialization.
+ prepared._cached_serializers = serializers
+ return prepared
+
+ def test_cython_path_normal_serialization(self):
+ """Cython fast path produces the same result as the plain Python path."""
+ column_metadata = [
+ ColumnMetadata("keyspace", "cf", "c0", Int32Type),
+ ColumnMetadata("keyspace", "cf", "c1", Int32Type),
+ ]
+ serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)]
+ prepared = self._make_prepared(column_metadata, serializers)
+
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((42, -1))
+ assert bound.values == [
+ Int32Type.serialize(42, self.protocol_version),
+ Int32Type.serialize(-1, self.protocol_version),
+ ]
+
+ def test_cython_path_none_value(self):
+ """None values pass through the Cython path without serialization."""
+ column_metadata = [ColumnMetadata("keyspace", "cf", "c0", Int32Type)]
+ serializers = [StubSerializer(Int32Type)]
+ prepared = self._make_prepared(column_metadata, serializers)
+
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((None,))
+ assert bound.values == [None]
+
+ def test_cython_path_unset_value(self):
+ """UNSET_VALUE is handled correctly in the Cython fast path (v4+)."""
+ column_metadata = [
+ ColumnMetadata("keyspace", "cf", "c0", Int32Type),
+ ColumnMetadata("keyspace", "cf", "c1", Int32Type),
+ ]
+ serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)]
+ prepared = self._make_prepared(column_metadata, serializers)
+
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((42, UNSET_VALUE))
+ assert bound.values[0] == Int32Type.serialize(42, self.protocol_version)
+ assert bound.values[1] == UNSET_VALUE
+
+ def test_cython_path_overflow_error_wrapped(self):
+ """struct.error from Cython int32 range check is caught and wrapped with column context."""
+ column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)]
+ serializers = [OverflowSerializer(Int32Type)]
+ prepared = self._make_prepared(column_metadata, serializers)
+
+ bound = BoundStatement(prepared_statement=prepared)
+ with pytest.raises(TypeError) as exc:
+ bound.bind((2**31,))
+ msg = str(exc.value)
+ assert "v0" in msg
+ assert "Int32Type" in msg
+ assert "int" in msg
+
+ def test_cython_path_type_error_wrapped(self):
+ """TypeError from serializer is caught and wrapped with column context."""
+ column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)]
+ serializers = [StubSerializer(Int32Type)]
+ prepared = self._make_prepared(column_metadata, serializers)
+
+ bound = BoundStatement(prepared_statement=prepared)
+ with pytest.raises(TypeError) as exc:
+ bound.bind(("not_an_int",))
+ msg = str(exc.value)
+ assert "v0" in msg
+ assert "Int32Type" in msg
+
+ def test_plain_path_overflow_error_wrapped(self):
+ """Out-of-range int in the plain Python path raises OverflowError (from
+ the Cython serializer) or struct.error (from the pure-Python fallback)
+ and is wrapped with column context."""
+ column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)]
+ # Force the plain Python path (no Cython serializers)
+ prepared = self._make_prepared(column_metadata, serializers=None)
+
+ bound = BoundStatement(prepared_statement=prepared)
+ with pytest.raises(TypeError) as exc:
+ bound.bind((2**31,))
+ msg = str(exc.value)
+ assert "v0" in msg
+ assert "Int32Type" in msg
+
+
+class UnsetValueBindingTest(unittest.TestCase):
+ """Tests for UNSET_VALUE handling in all bind paths.
+
+ These specifically test UNSET_VALUE in non-trailing positions to catch
+ index-management bugs in the pre-allocated values list.
+ """
+
+ protocol_version = 4
+
+ def _make_prepared(
+ self, column_metadata, serializers=None, routing_key_indexes=None
+ ):
+ prepared = PreparedStatement(
+ column_metadata=column_metadata,
+ query_id=None,
+ routing_key_indexes=routing_key_indexes or [],
+ query=None,
+ keyspace="keyspace",
+ protocol_version=self.protocol_version,
+ result_metadata=None,
+ result_metadata_id=None,
+ )
+ prepared._cached_serializers = serializers
+ return prepared
+
+ def _three_column_metadata(self):
+ return [
+ ColumnMetadata("keyspace", "cf", "c0", Int32Type),
+ ColumnMetadata("keyspace", "cf", "c1", Int32Type),
+ ColumnMetadata("keyspace", "cf", "c2", Int32Type),
+ ]
+
+ # --- Plain Python path (no serializers) ---
+
+ def test_plain_unset_mid_list(self):
+ """UNSET_VALUE in the middle of a value list does not corrupt indices."""
+ col_meta = self._three_column_metadata()
+ prepared = self._make_prepared(col_meta, serializers=None)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((0, UNSET_VALUE, 2))
+ assert bound.values == [
+ Int32Type.serialize(0, self.protocol_version),
+ UNSET_VALUE,
+ Int32Type.serialize(2, self.protocol_version),
+ ]
+
+ def test_plain_unset_first(self):
+ """UNSET_VALUE as the first value does not corrupt indices."""
+ col_meta = self._three_column_metadata()
+ prepared = self._make_prepared(col_meta, serializers=None)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((UNSET_VALUE, 1, 2))
+ assert bound.values == [
+ UNSET_VALUE,
+ Int32Type.serialize(1, self.protocol_version),
+ Int32Type.serialize(2, self.protocol_version),
+ ]
+
+ def test_plain_all_unset(self):
+ """All values are UNSET_VALUE."""
+ col_meta = self._three_column_metadata()
+ prepared = self._make_prepared(col_meta, serializers=None)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((UNSET_VALUE, UNSET_VALUE, UNSET_VALUE))
+ assert bound.values == [UNSET_VALUE, UNSET_VALUE, UNSET_VALUE]
+
+ def test_plain_unset_trailing(self):
+ """UNSET_VALUE as the last explicit value."""
+ col_meta = self._three_column_metadata()
+ prepared = self._make_prepared(col_meta, serializers=None)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((0, 1, UNSET_VALUE))
+ assert bound.values == [
+ Int32Type.serialize(0, self.protocol_version),
+ Int32Type.serialize(1, self.protocol_version),
+ UNSET_VALUE,
+ ]
+
+ def test_plain_implicit_unset_fill(self):
+ """Fewer values than columns fills remaining with implicit UNSET_VALUE."""
+ col_meta = self._three_column_metadata()
+ prepared = self._make_prepared(col_meta, serializers=None)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((0,))
+ assert bound.values == [
+ Int32Type.serialize(0, self.protocol_version),
+ UNSET_VALUE,
+ UNSET_VALUE,
+ ]
+
+ def test_plain_mixed_none_and_unset(self):
+ """Mix of None and UNSET_VALUE in the same bind."""
+ col_meta = self._three_column_metadata()
+ prepared = self._make_prepared(col_meta, serializers=None)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((None, UNSET_VALUE, 2))
+ assert bound.values == [
+ None,
+ UNSET_VALUE,
+ Int32Type.serialize(2, self.protocol_version),
+ ]
+
+ # --- Cython serializer path (with stub serializers) ---
+
+ def test_cython_unset_mid_list(self):
+ """UNSET_VALUE in the middle with Cython serializers does not corrupt indices."""
+ col_meta = self._three_column_metadata()
+ serializers = [StubSerializer(Int32Type)] * 3
+ prepared = self._make_prepared(col_meta, serializers=serializers)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((0, UNSET_VALUE, 2))
+ assert bound.values == [
+ Int32Type.serialize(0, self.protocol_version),
+ UNSET_VALUE,
+ Int32Type.serialize(2, self.protocol_version),
+ ]
+
+ def test_cython_unset_first(self):
+ """UNSET_VALUE as the first value with Cython serializers."""
+ col_meta = self._three_column_metadata()
+ serializers = [StubSerializer(Int32Type)] * 3
+ prepared = self._make_prepared(col_meta, serializers=serializers)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((UNSET_VALUE, 1, 2))
+ assert bound.values == [
+ UNSET_VALUE,
+ Int32Type.serialize(1, self.protocol_version),
+ Int32Type.serialize(2, self.protocol_version),
+ ]
+
+ def test_cython_all_unset(self):
+ """All values are UNSET_VALUE with Cython serializers."""
+ col_meta = self._three_column_metadata()
+ serializers = [StubSerializer(Int32Type)] * 3
+ prepared = self._make_prepared(col_meta, serializers=serializers)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((UNSET_VALUE, UNSET_VALUE, UNSET_VALUE))
+ assert bound.values == [UNSET_VALUE, UNSET_VALUE, UNSET_VALUE]
+
+ def test_cython_mixed_none_and_unset(self):
+ """Mix of None and UNSET_VALUE with Cython serializers."""
+ col_meta = self._three_column_metadata()
+ serializers = [StubSerializer(Int32Type)] * 3
+ prepared = self._make_prepared(col_meta, serializers=serializers)
+ bound = BoundStatement(prepared_statement=prepared)
+ bound.bind((None, UNSET_VALUE, 2))
+ assert bound.values == [
+ None,
+ UNSET_VALUE,
+ Int32Type.serialize(2, self.protocol_version),
+ ]