diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..af9ceff0eb 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -33,7 +33,15 @@ from cassandra.protocol import _UNSET_VALUE from cassandra.util import OrderedDict, _sanitize_identifiers +try: + from cassandra.serializers import make_serializers as _cython_make_serializers + + _HAVE_CYTHON_SERIALIZERS = True +except ImportError: + _HAVE_CYTHON_SERIALIZERS = False + import logging + log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE @@ -49,9 +57,9 @@ Only valid when using native protocol v4+ """ -NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') -START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') -END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") _clean_name_cache = {} @@ -60,7 +68,9 @@ def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: - clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) _clean_name_cache[name] = clean return clean @@ -83,6 +93,7 @@ def tuple_factory(colnames, rows): """ return rows + class PseudoNamedTupleRow(object): """ Helper class for pseudo_named_tuple_factory. These objects provide an @@ -90,6 +101,7 @@ class PseudoNamedTupleRow(object): but otherwise do not attempt to implement the full namedtuple or iterable interface. """ + def __init__(self, ordered_dict): self._dict = ordered_dict self._tuple = tuple(ordered_dict.values()) @@ -104,8 +116,7 @@ def __iter__(self): return iter(self._tuple) def __repr__(self): - return '{t}({od})'.format(t=self.__class__.__name__, - od=self._dict) + return "{t}({od})".format(t=self.__class__.__name__, od=self._dict) def pseudo_namedtuple_factory(colnames, rows): @@ -113,8 +124,7 @@ def pseudo_namedtuple_factory(colnames, rows): Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback factory for cases where :meth:`.named_tuple_factory` fails to create rows. """ - return [PseudoNamedTupleRow(od) - for od in ordered_dict_factory(colnames, rows)] + return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)] def named_tuple_factory(colnames, rows): @@ -148,7 +158,7 @@ def named_tuple_factory(colnames, rows): """ clean_column_names = map(_clean_column_name, colnames) try: - Row = namedtuple('Row', clean_column_names) + Row = namedtuple("Row", clean_column_names) except SyntaxError: warnings.warn( "Failed creating namedtuple for a result because there were too " @@ -159,19 +169,23 @@ def named_tuple_factory(colnames, rows): "values on row objects, Upgrade to Python 3.7, or use a different " "row factory. (column names: {colnames})".format( substitute_factory_name=pseudo_namedtuple_factory.__name__, - colnames=colnames + colnames=colnames, ) ) return pseudo_namedtuple_factory(colnames, rows) except Exception: - clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt - log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " - "(see Python 'namedtuple' documentation for details on name rules). " - "Results will be returned with positional names. " - "Avoid this by choosing different names, using SELECT \"\" AS aliases, " - "or specifying a different row_factory on your Session" % - (colnames, clean_column_names)) - Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) + clean_column_names = list( + map(_clean_column_name, colnames) + ) # create list because py3 map object will be consumed by first attempt + log.warning( + "Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. " + 'Avoid this by choosing different names, using SELECT "" AS aliases, ' + "or specifying a different row_factory on your Session" + % (colnames, clean_column_names) + ) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] @@ -276,11 +290,24 @@ class Statement(object): _serial_consistency_level = None _routing_key = None - def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, - is_idempotent=False, table=None): - if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors - raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + def __init__( + self, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + table=None, + ): + if retry_policy and not hasattr( + retry_policy, "on_read_timeout" + ): # just checking one method to detect positional parameter errors + raise ValueError( + "retry_policy should implement cassandra.policies.RetryPolicy" + ) if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: @@ -329,17 +356,20 @@ def _del_routing_key(self): If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. - """) + """, + ) def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - if (serial_consistency_level is not None and - not ConsistencyLevel.is_serial(serial_consistency_level)): + if serial_consistency_level is not None and not ConsistencyLevel.is_serial( + serial_consistency_level + ): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL") + "or ConsistencyLevel.LOCAL_SERIAL" + ) self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): @@ -384,7 +414,8 @@ def is_lwt(self): conditional statements. .. versionadded:: 2.0.0 - """) + """, + ) class SimpleStatement(Statement): @@ -392,9 +423,18 @@ class SimpleStatement(Statement): A simple, un-prepared query. """ - def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None, is_idempotent=False): + def __init__( + self, + query_string, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + ): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the @@ -402,8 +442,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + is_idempotent, + ) self._query_string = query_string @property @@ -411,9 +460,14 @@ def query_string(self): return self._query_string def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -442,7 +496,7 @@ class PreparedStatement(object): A note about * in prepared statements """ - column_metadata = None #TODO: make this bind_metadata in next major + column_metadata = None # TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None @@ -459,9 +513,19 @@ class PreparedStatement(object): serial_consistency_level = None # TODO never used? _is_lwt = False - def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata, result_metadata_id, - is_lwt=False, column_encryption_policy=None): + def __init__( + self, + column_metadata, + query_id, + routing_key_indexes, + query, + keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt=False, + column_encryption_policy=None, + ): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -474,14 +538,61 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self.is_idempotent = False self._is_lwt = is_lwt + @property + def _serializers(self): + """Lazily create and cache Cython serializers for column types. + + Returns a list of Serializer objects if Cython serializers are available + and there is no column encryption policy, otherwise returns None. + + The column_encryption_policy check is performed on every access (not + cached) so that serializers are correctly bypassed if a policy is set + after construction. This means the cache never goes stale: once a CE + policy is present, we always return None and fall through to the + encryption-aware bind path. + """ + if self.column_encryption_policy: + return None + try: + return self._cached_serializers + except AttributeError: + pass + if _HAVE_CYTHON_SERIALIZERS and self.column_metadata: + self._cached_serializers = _cython_make_serializers( + [col.type for col in self.column_metadata] + ) + else: + self._cached_serializers = None + return self._cached_serializers + @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy=None): + def from_message( + cls, + query_id, + column_metadata, + pk_indexes, + cluster_metadata, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy=None, + ): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + return PreparedStatement( + column_metadata, + query_id, + None, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) if pk_indexes: routing_key_indexes = pk_indexes @@ -496,18 +607,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement - statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) + statement_indexes = dict( + (c.name, i) for i, c in enumerate(column_metadata) + ) # a list of which indexes in the statement correspond to partition key items try: - routing_key_indexes = [statement_indexes[c.name] - for c in partition_key_columns] - except KeyError: # we're missing a partition key component in the prepared - pass # statement; just leave routing_key_indexes as None - - return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + routing_key_indexes = [ + statement_indexes[c.name] for c in partition_key_columns + ] + except ( + KeyError + ): # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None + + return PreparedStatement( + column_metadata, + query_id, + routing_key_indexes, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) def bind(self, values): """ @@ -519,19 +644,54 @@ def bind(self, values): def is_routing_key_index(self, i): if self._routing_key_index_set is None: - self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() + self._routing_key_index_set = ( + set(self.routing_key_indexes) if self.routing_key_indexes else set() + ) return i in self._routing_key_index_set def is_lwt(self): return self._is_lwt def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ +def _raise_bind_serialize_error(col_spec, value, exc): + """Wrap TypeError, struct.error, or OverflowError with column context. + + Called from all three bind loop paths (CE, Cython, plain Python) to + provide a uniform error message that includes the column name and + expected type. The message distinguishes between wrong-type and + out-of-range scenarios: + + - TypeError -> "invalid type" + - OverflowError -> "value out of range" + - struct.error -> "value out of range" + + Other exception types (e.g. ValueError from VectorType dimension + mismatch) propagate without wrapping. + """ + actual_type = type(value) + if isinstance(exc, (OverflowError, struct.error)): + reason = "value out of range" + else: + reason = "invalid type" + message = ( + 'Received an argument with %s for column "%s". ' + "Expected: %s, Got: %s; (%s)" + % (reason, col_spec.name, col_spec.type, actual_type, exc) + ) + raise TypeError(message) from exc + + class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. @@ -548,9 +708,17 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + def __init__( + self, + prepared_statement, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + ): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. @@ -571,9 +739,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None self.keyspace = meta[0].keyspace_name self.table = meta[0].table_name - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, - prepared_statement.is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + prepared_statement.is_idempotent, + ) def bind(self, values): """ @@ -615,64 +791,132 @@ def bind(self, values): values.append(UNSET_VALUE) else: raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + "Column name `%s` not found in bound dict." % (col.name) + ) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( - "Too many arguments provided to bind() (got %d, expected %d)" % - (len(values), len(col_meta))) + "Too many arguments provided to bind() (got %d, expected %d)" + % (len(values), len(col_meta)) + ) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. - if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ - value_len < len(self.prepared_statement.routing_key_indexes): + if ( + proto_version < 4 + and self.prepared_statement.routing_key_indexes + and value_len < len(self.prepared_statement.routing_key_indexes) + ): raise ValueError( - "Too few arguments provided to bind() (got %d, required %d for routing key)" % - (value_len, len(self.prepared_statement.routing_key_indexes))) + "Too few arguments provided to bind() (got %d, required %d for routing key)" + % (value_len, len(self.prepared_statement.routing_key_indexes)) + ) self.raw_values = values - self.values = [] - for value, col_spec in zip(values, col_meta): - if value is None: - self.values.append(None) - elif value is UNSET_VALUE: - if proto_version >= 4: - self._append_unset_value() + # Pre-allocate to avoid repeated list growth reallocations + self.values = [None] * col_meta_len + idx = 0 + if ce_policy: + # Column encryption path: check each column for CE policy + for value, col_spec in zip(values, col_meta): + if value is None: + self.values[idx] = None + elif value is UNSET_VALUE: + if proto_version >= 4: + idx = self._append_unset_value(idx) + continue + else: + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + try: + col_desc = ColDesc( + col_spec.keyspace_name, col_spec.table_name, col_spec.name + ) + uses_ce = ce_policy.contains_column(col_desc) + if uses_ce: + col_type = ce_policy.column_type(col_desc) + col_bytes = col_type.serialize(value, proto_version) + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + else: + col_bytes = col_spec.type.serialize(value, proto_version) + self.values[idx] = col_bytes + # struct.error: int32 out-of-range; OverflowError: float out-of-range + except (TypeError, struct.error, OverflowError) as exc: + _raise_bind_serialize_error(col_spec, value, exc) + idx += 1 + else: + # Fast path: no column encryption, use Cython serializers if available + serializers = self.prepared_statement._serializers + if serializers is not None: + for ser, value, col_spec in zip(serializers, values, col_meta): + if value is None: + self.values[idx] = None + elif value is UNSET_VALUE: + if proto_version >= 4: + idx = self._append_unset_value(idx) + continue + else: + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) + else: + try: + col_bytes = ser.serialize(value, proto_version) + self.values[idx] = col_bytes + # struct.error: int32 out-of-range; OverflowError: float out-of-range + except (TypeError, struct.error, OverflowError) as exc: + _raise_bind_serialize_error(col_spec, value, exc) + idx += 1 else: - try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) - uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) - if uses_ce: - col_bytes = ce_policy.encrypt(col_desc, col_bytes) - self.values.append(col_bytes) - except (TypeError, struct.error) as exc: - actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) - raise TypeError(message) + for value, col_spec in zip(values, col_meta): + if value is None: + self.values[idx] = None + elif value is UNSET_VALUE: + if proto_version >= 4: + idx = self._append_unset_value(idx) + continue + else: + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) + else: + try: + col_bytes = col_spec.type.serialize(value, proto_version) + self.values[idx] = col_bytes + # struct.error: int32 out-of-range; OverflowError: float out-of-range + except (TypeError, struct.error, OverflowError) as exc: + _raise_bind_serialize_error(col_spec, value, exc) + idx += 1 if proto_version >= 4: - diff = col_meta_len - len(self.values) - if diff: - for _ in range(diff): - self._append_unset_value() + # Fill remaining unbound columns with UNSET_VALUE (v4+ feature). + # The pre-allocated list already has slots for these, so index + # assignment works directly without trimming first. + while idx < col_meta_len: + idx = self._append_unset_value(idx) + elif idx < col_meta_len: + # Pre-v4: trim trailing unused slots (no UNSET_VALUE support) + self.values = self.values[:idx] return self - def _append_unset_value(self): - next_index = len(self.values) - if self.prepared_statement.is_routing_key_index(next_index): - col_meta = self.prepared_statement.column_metadata[next_index] - raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) - self.values.append(UNSET_VALUE) + def _append_unset_value(self, idx): + if self.prepared_statement.is_routing_key_index(idx): + col_meta = self.prepared_statement.column_metadata[idx] + raise ValueError( + "Cannot bind UNSET_VALUE as a part of the routing key '%s'" + % col_meta.name + ) + self.values[idx] = UNSET_VALUE + return idx + 1 @property def routing_key(self): @@ -686,7 +930,9 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) + self._routing_key = b"".join( + self._key_parts_packed(self.values[i] for i in routing_indexes) + ) return self._routing_key @@ -694,9 +940,15 @@ def is_lwt(self): return self.prepared_statement.is_lwt() def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.prepared_statement.query_string, self.raw_values, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.prepared_statement.query_string, + self.raw_values, + consistency, + ) + __repr__ = __str__ @@ -731,7 +983,7 @@ def __str__(self): return self.name def __repr__(self): - return "BatchType.%s" % (self.name, ) + return "BatchType.%s" % (self.name,) BatchType.LOGGED = BatchType("LOGGED", 0) @@ -763,9 +1015,15 @@ class BatchStatement(Statement): _session = None _is_lwt = False - def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, - consistency_level=None, serial_consistency_level=None, - session=None, custom_payload=None): + def __init__( + self, + batch_type=BatchType.LOGGED, + retry_policy=None, + consistency_level=None, + serial_consistency_level=None, + session=None, + custom_payload=None, + ): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. @@ -813,8 +1071,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, self.batch_type = batch_type self._statements_and_parameters = [] self._session = session - Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, - serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + Statement.__init__( + self, + retry_policy=retry_policy, + consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level, + custom_payload=custom_payload, + ) def clear(self): """ @@ -853,11 +1116,14 @@ def add(self, statement, parameters=None): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " - "to BatchStatement.add()") + "to BatchStatement.add()" + ) self._update_state(statement) if statement.is_lwt(): self._is_lwt = True - self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) + self._add_statement_and_params( + True, statement.prepared_statement.query_id, statement.values + ) else: # it must be a SimpleStatement query_string = statement.query_string @@ -881,7 +1147,9 @@ def add_all(self, statements, parameters): def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: - raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + raise ValueError( + "Batch statement cannot contain more than %d statements." % 0xFFFF + ) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): @@ -907,9 +1175,15 @@ def __len__(self): return len(self._statements_and_parameters) def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.batch_type, len(self), consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return "" % ( + self.batch_type, + len(self), + consistency, + ) + __repr__ = __str__ @@ -931,7 +1205,9 @@ def __str__(self): def bind_params(query, params, encoder): if isinstance(params, dict): - return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) + return query % dict( + (k, encoder.cql_encode_all_types(v)) for k, v in params.items() + ) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) @@ -940,6 +1216,7 @@ class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ + pass @@ -1000,7 +1277,9 @@ class QueryTrace(object): _session = None - _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" + _SELECT_SESSIONS_FORMAT = ( + "SELECT * FROM system_traces.sessions WHERE session_id = %s" + ) _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 @@ -1029,18 +1308,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( - "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." + % (max_wait,) + ) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) - metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout + metadata_request_timeout = ( + self._session.cluster.control_connection + and self._session.cluster.control_connection._metadata_request_timeout + ) session_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_SESSIONS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), + (self.trace_id,), + time_spent, + max_wait, + ) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries session_row = session_results.one() if session_results else None - is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None + is_complete = ( + session_row is not None + and session_row.duration is not None + and session_row.started_at is not None + ) if not session_results or (wait_for_complete and not is_complete): - time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + time.sleep(self._BASE_RETRY_SLEEP * (2**attempt)) attempt += 1 continue if is_complete: @@ -1049,29 +1346,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) self.request_type = session_row.request - self.duration = timedelta(microseconds=session_row.duration) if is_complete else None + self.duration = ( + timedelta(microseconds=session_row.duration) if is_complete else None + ) self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 - self.client = getattr(session_row, 'client', None) + self.client = getattr(session_row, "client", None) - log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + log.debug( + "Attempting to fetch trace events for trace ID: %s", self.trace_id + ) time_spent = time.time() - start event_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout), - consistency_level=query_cl), + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_EVENTS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), (self.trace_id,), time_spent, - max_wait) + max_wait, + ) log.debug("Fetched trace events for trace ID: %s", self.trace_id) - self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) - for r in event_results) + self.events = tuple( + TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) + for r in event_results + ) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_response_future( + query, parameters, trace=False, custom_payload=None, timeout=timeout + ) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() @@ -1079,12 +1389,22 @@ def _execute(self, query, parameters, time_spent, max_wait): try: return future.result() except OperationTimedOut: - raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + raise TraceUnavailable( + "Trace information was not available within %f seconds" % (max_wait,) + ) def __str__(self): - return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ - % (self.request_type, self.trace_id, self.coordinator, self.started_at, - self.duration, self.parameters) + return ( + "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" + % ( + self.request_type, + self.trace_id, + self.coordinator, + self.started_at, + self.duration, + self.parameters, + ) + ) class TraceEvent(object): @@ -1121,7 +1441,9 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) + self.datetime = datetime.fromtimestamp( + unix_time_from_uuid1(timeuuid), tz=timezone.utc + ) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) @@ -1130,7 +1452,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.thread_name = thread_name def __str__(self): - return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + return "%s on %s[%s] at %s" % ( + self.description, + self.source, + self.thread_name, + self.datetime, + ) # TODO remove next major since we can target using the `host` attribute of session.execute @@ -1139,9 +1466,12 @@ class HostTargetingStatement(object): Wraps any query statement and attaches a target host, making it usable in a targeted LBP without modifying the user's statement. """ + def __init__(self, inner_statement, target_host): - self.__class__ = type(inner_statement.__class__.__name__, - (self.__class__, inner_statement.__class__), - {}) - self.__dict__ = inner_statement.__dict__ - self.target_host = target_host + self.__class__ = type( + inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}, + ) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host diff --git a/cassandra/serializers.pxd b/cassandra/serializers.pxd new file mode 100644 index 0000000000..60297077a8 --- /dev/null +++ b/cassandra/serializers.pxd @@ -0,0 +1,20 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef class Serializer: + # The cqltypes._CassandraType corresponding to this serializer + cdef object cqltype + + cpdef bytes serialize(self, object value, int protocol_version) diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx new file mode 100644 index 0000000000..b02d0861cc --- /dev/null +++ b/cassandra/serializers.pyx @@ -0,0 +1,398 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cython-optimized serializers for CQL types. + +Mirrors the architecture of deserializers.pyx. Currently implements +optimized serialization for: +- FloatType (4-byte big-endian float) +- DoubleType (8-byte big-endian double) +- Int32Type (4-byte big-endian signed int) +- VectorType (type-specialized for float/double/int32, generic fallback) + +For all other types, GenericSerializer delegates to the Python-level +cqltype.serialize() classmethod. +""" + +from libc.stdint cimport int32_t +from libc.string cimport memcpy +from libc.float cimport FLT_MAX +from libc.math cimport isinf, isnan +from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING + +from cassandra import cqltypes +import io +from cassandra.marshal import uvint_pack + +cdef bint is_little_endian +from cassandra.util import is_little_endian + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + +cdef class Serializer: + """Cython-based serializer class for a cqltype""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + cpdef bytes serialize(self, object value, int protocol_version): + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Float range check +# --------------------------------------------------------------------------- + +cdef inline void _check_float_range(double value) except *: + """Raise OverflowError for finite values outside float32 range. + + Matches the behaviour of struct.pack('>f', value), which raises + OverflowError for values that cannot be represented as a 32-bit + IEEE 754 float. inf, -inf, and nan pass through unchanged. + """ + if not isinf(value) and not isnan(value): + if value > FLT_MAX or value < -FLT_MAX: + raise OverflowError( + "Value %r too large for float32 (max %r)" % (value, FLT_MAX) + ) + + +# --------------------------------------------------------------------------- +# Int32 range check +# --------------------------------------------------------------------------- + +cdef inline void _check_int32_range(object value) except *: + """Raise OverflowError for values outside the signed int32 range. + + Mirrors ``_check_float_range``: we intentionally raise OverflowError + (not struct.error) so callers only need to catch one exception type + for out-of-range values. The check must be done on the Python int + *before* the C-level cast, which would silently truncate. + """ + if value > 2147483647 or value < -2147483648: + raise OverflowError( + "'i' format requires -2147483648 <= number <= 2147483647" + ) + + +# --------------------------------------------------------------------------- +# Scalar serializers +# --------------------------------------------------------------------------- + +cdef class SerFloatType(Serializer): + """Serialize a Python float to 4-byte big-endian IEEE 754.""" + + cpdef bytes serialize(self, object value, int protocol_version): + _check_float_range(value) + cdef float val = value + cdef char out[4] + cdef char *src = &val + + if is_little_endian: + out[0] = src[3] + out[1] = src[2] + out[2] = src[1] + out[3] = src[0] + else: + memcpy(out, src, 4) + + return PyBytes_FromStringAndSize(out, 4) + + +cdef class SerDoubleType(Serializer): + """Serialize a Python float to 8-byte big-endian IEEE 754.""" + + cpdef bytes serialize(self, object value, int protocol_version): + cdef double val = value + cdef char out[8] + cdef char *src = &val + + if is_little_endian: + out[0] = src[7] + out[1] = src[6] + out[2] = src[5] + out[3] = src[4] + out[4] = src[3] + out[5] = src[2] + out[6] = src[1] + out[7] = src[0] + else: + memcpy(out, src, 8) + + return PyBytes_FromStringAndSize(out, 8) + + +cdef class SerInt32Type(Serializer): + """Serialize a Python int to 4-byte big-endian signed int32.""" + + cpdef bytes serialize(self, object value, int protocol_version): + _check_int32_range(value) + cdef int32_t val = value + cdef char out[4] + cdef char *src = &val + + if is_little_endian: + out[0] = src[3] + out[1] = src[2] + out[2] = src[1] + out[3] = src[0] + else: + memcpy(out, src, 4) + + return PyBytes_FromStringAndSize(out, 4) + + +# --------------------------------------------------------------------------- +# Type detection helpers +# --------------------------------------------------------------------------- + +cdef inline bint _is_float_type(object subtype): + return subtype is cqltypes.FloatType or issubclass(subtype, cqltypes.FloatType) + +cdef inline bint _is_double_type(object subtype): + return subtype is cqltypes.DoubleType or issubclass(subtype, cqltypes.DoubleType) + +cdef inline bint _is_int32_type(object subtype): + return subtype is cqltypes.Int32Type or issubclass(subtype, cqltypes.Int32Type) + + +# --------------------------------------------------------------------------- +# VectorType serializer +# --------------------------------------------------------------------------- + +cdef class SerVectorType(Serializer): + """ + Optimized Cython serializer for VectorType. + + For float, double, and int32 vectors, pre-allocates a contiguous buffer + and uses C-level byte swapping. For other subtypes, falls back to + per-element Python serialization. + """ + + cdef int vector_size + cdef object subtype + # 0 = generic, 1 = float, 2 = double, 3 = int32 + cdef int type_code + + def __init__(self, cqltype): + super().__init__(cqltype) + self.vector_size = cqltype.vector_size + self.subtype = cqltype.subtype + + if _is_float_type(self.subtype): + self.type_code = 1 + elif _is_double_type(self.subtype): + self.type_code = 2 + elif _is_int32_type(self.subtype): + self.type_code = 3 + else: + self.type_code = 0 + + cpdef bytes serialize(self, object value, int protocol_version): + cdef int v_length = len(value) + if v_length != self.vector_size: + raise ValueError( + "Expected sequence of size %d for vector of type %s and " + "dimension %d, observed sequence of length %d" % ( + self.vector_size, self.subtype.typename, + self.vector_size, v_length)) + + if self.type_code == 1: + return self._serialize_float(value) + elif self.type_code == 2: + return self._serialize_double(value) + elif self.type_code == 3: + return self._serialize_int32(value) + else: + return self._serialize_generic(value, protocol_version) + + cdef inline bytes _serialize_float(self, object values): + """Serialize a sequence of floats into a contiguous big-endian buffer. + + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. + """ + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 4 + if buf_size == 0: + return b"" + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) + + cdef float val + cdef char *src + cdef char *dst + + for i in range(self.vector_size): + _check_float_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return result + + cdef inline bytes _serialize_double(self, object values): + """Serialize a sequence of doubles into a contiguous big-endian buffer. + + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. + """ + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 8 + if buf_size == 0: + return b"" + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) + + cdef double val + cdef char *src + cdef char *dst + + for i in range(self.vector_size): + val = values[i] + src = &val + dst = buf + i * 8 + + if is_little_endian: + dst[0] = src[7] + dst[1] = src[6] + dst[2] = src[5] + dst[3] = src[4] + dst[4] = src[3] + dst[5] = src[2] + dst[6] = src[1] + dst[7] = src[0] + else: + memcpy(dst, src, 8) + + return result + + cdef inline bytes _serialize_int32(self, object values): + """Serialize a sequence of int32 values into a contiguous big-endian buffer. + + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. + """ + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 4 + if buf_size == 0: + return b"" + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) + + cdef int32_t val + cdef char *src + cdef char *dst + + for i in range(self.vector_size): + _check_int32_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return result + + cdef inline bytes _serialize_generic(self, object values, int protocol_version): + """Fallback: element-by-element Python serialization for non-optimized types.""" + serialized_size = self.subtype.serial_size() + buf = io.BytesIO() + for item in values: + item_bytes = self.subtype.serialize(item, protocol_version) + if serialized_size is None: + buf.write(uvint_pack(len(item_bytes))) + buf.write(item_bytes) + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# Generic serializer (fallback for all other types) +# --------------------------------------------------------------------------- + +cdef class GenericSerializer(Serializer): + """ + Wraps a generic cqltype for serialization, delegating to the Python-level + cqltype.serialize() classmethod. + """ + + cpdef bytes serialize(self, object value, int protocol_version): + return self.cqltype.serialize(value, protocol_version) + + def __repr__(self): + return "GenericSerializer(%s)" % (self.cqltype,) + + +# --------------------------------------------------------------------------- +# Lookup and factory +# --------------------------------------------------------------------------- + +cdef dict _ser_classes = {} + +cpdef Serializer find_serializer(cqltype): + """Find a serializer for a cqltype.""" + + # For VectorType, always use SerVectorType (it handles generic subtypes internally) + if issubclass(cqltype, cqltypes.VectorType): + return SerVectorType(cqltype) + + # For scalar types with dedicated serializers, look up by name + name = 'Ser' + cqltype.__name__ + cls = _ser_classes.get(name) + if cls is not None: + return cls(cqltype) + + # Fallback to generic + return GenericSerializer(cqltype) + + +def make_serializers(cqltypes_list): + """Create a list of Serializer objects for each given cqltype.""" + return [find_serializer(ct) for ct in cqltypes_list] + + +# Build the lookup dict for scalar serializers at module load time +_ser_classes['SerFloatType'] = SerFloatType +_ser_classes['SerDoubleType'] = SerDoubleType +_ser_classes['SerInt32Type'] = SerInt32Type diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 5416ac461d..59c238169b 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -1,4 +1,4 @@ -# Copyright DataStax, Inc. +# Copyright ScyllaDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,18 @@ # limitations under the License. import unittest +import struct import pytest from cassandra.encoder import Encoder from cassandra.protocol import ColumnMetadata -from cassandra.query import (bind_params, ValueSequence, PreparedStatement, - BoundStatement, UNSET_VALUE) +from cassandra.query import ( + bind_params, + ValueSequence, + PreparedStatement, + BoundStatement, + UNSET_VALUE, +) from cassandra.cqltypes import Int32Type from cassandra.util import OrderedDict @@ -26,7 +32,6 @@ class ParamBindingTest(unittest.TestCase): - def test_bind_sequence(self): result = bind_params("%s %s %s", (1, "a", 2.0), Encoder()) assert result == "1 'a' 2.0" @@ -48,18 +53,18 @@ def test_none_param(self): assert result == "NULL" def test_list_collection(self): - result = bind_params("%s", (['a', 'b', 'c'],), Encoder()) + result = bind_params("%s", (["a", "b", "c"],), Encoder()) assert result == "['a', 'b', 'c']" def test_set_collection(self): - result = bind_params("%s", (set(['a', 'b']),), Encoder()) + result = bind_params("%s", (set(["a", "b"]),), Encoder()) assert result in ("{'a', 'b'}", "{'b', 'a'}") def test_map_collection(self): vals = OrderedDict() - vals['a'] = 'a' - vals['b'] = 'b' - vals['c'] = 'c' + vals["a"] = "a" + vals["b"] = "b" + vals["c"] = "c" result = bind_params("%s", (vals,), Encoder()) assert result == "{'a': 'a', 'b': 'b', 'c': 'c'}" @@ -68,63 +73,70 @@ def test_quote_escaping(self): assert result == """'''ef''''ef"ef""ef'''""" def test_float_precision(self): - f = 3.4028234663852886e+38 + f = 3.4028234663852886e38 assert float(bind_params("%s", (f,), Encoder())) == f -class BoundStatementTestV3(unittest.TestCase): +class BoundStatementTestV3(unittest.TestCase): protocol_version = 3 @classmethod def setUpClass(cls): - column_metadata = [ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type), - ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] - cls.prepared = PreparedStatement(column_metadata=column_metadata, - query_id=None, - routing_key_indexes=[1, 0], - query=None, - keyspace='keyspace', - protocol_version=cls.protocol_version, result_metadata=None, - result_metadata_id=None) + column_metadata = [ + ColumnMetadata("keyspace", "cf", "rk0", Int32Type), + ColumnMetadata("keyspace", "cf", "rk1", Int32Type), + ColumnMetadata("keyspace", "cf", "ck0", Int32Type), + ColumnMetadata("keyspace", "cf", "v0", Int32Type), + ] + cls.prepared = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[1, 0], + query=None, + keyspace="keyspace", + protocol_version=cls.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) cls.bound = BoundStatement(prepared_statement=cls.prepared) def test_invalid_argument_type(self): - values = (0, 0, 0, 'string not int') + values = (0, 0, 0, "string not int") with pytest.raises(TypeError) as exc: self.bound.bind(values) e = exc.value - assert 'v0' in str(e) - assert 'Int32Type' in str(e) - assert 'str' in str(e) + assert "v0" in str(e) + assert "Int32Type" in str(e) + assert "str" in str(e) - values = (['1', '2'], 0, 0, 0) + values = (["1", "2"], 0, 0, 0) with pytest.raises(TypeError) as exc: self.bound.bind(values) e = exc.value - assert 'rk0' in str(e) - assert 'Int32Type' in str(e) - assert 'list' in str(e) + assert "rk0" in str(e) + assert "Int32Type" in str(e) + assert "list" in str(e) def test_inherit_fetch_size(self): - keyspace = 'keyspace1' - column_family = 'cf1' + keyspace = "keyspace1" + column_family = "cf1" column_metadata = [ - ColumnMetadata(keyspace, column_family, 'foo1', Int32Type), - ColumnMetadata(keyspace, column_family, 'foo2', Int32Type) + ColumnMetadata(keyspace, column_family, "foo1", Int32Type), + ColumnMetadata(keyspace, column_family, "foo2", Int32Type), ] - prepared_statement = PreparedStatement(column_metadata=column_metadata, - query_id=None, - routing_key_indexes=[], - query=None, - keyspace=keyspace, - protocol_version=self.protocol_version, - result_metadata=None, - result_metadata_id=None) + prepared_statement = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[], + query=None, + keyspace=keyspace, + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) assert 1234 == bound_statement.fetch_size @@ -134,21 +146,23 @@ def test_too_few_parameters_for_routing_key(self): self.prepared.bind((1,)) bound = self.prepared.bind((1, 2)) - assert bound.keyspace == 'keyspace' + assert bound.keyspace == "keyspace" def test_dict_missing_routing_key(self): with pytest.raises(KeyError): - self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk0": 0, "ck0": 0, "v0": 0}) with pytest.raises(KeyError): - self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk1": 0, "ck0": 0, "v0": 0}) def test_missing_value(self): with pytest.raises(KeyError): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0}) def test_extra_value(self): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict - assert self.bound.values == [b'\x00' * 4] * 4 # four encoded zeros + self.bound.bind( + {"rk0": 0, "rk1": 0, "ck0": 0, "v0": 0, "should_not_be_here": 123} + ) # okay to have extra keys in dict + assert self.bound.values == [b"\x00" * 4] * 4 # four encoded zeros with pytest.raises(ValueError): self.bound.bind((0, 0, 0, 0, 123)) @@ -158,19 +172,21 @@ def test_values_none(self): self.bound.bind(None) # prepared statement with no values - prepared_statement = PreparedStatement(column_metadata=[], - query_id=None, - routing_key_indexes=[], - query=None, - keyspace='whatever', - protocol_version=self.protocol_version, - result_metadata=None, - result_metadata_id=None) + prepared_statement = PreparedStatement( + column_metadata=[], + query_id=None, + routing_key_indexes=[], + query=None, + keyspace="whatever", + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) bound = prepared_statement.bind(None) assertListEqual(bound.values, []) def test_bind_none(self): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": None}) assert self.bound.values[-1] == None old_values = self.bound.values @@ -180,7 +196,7 @@ def test_bind_none(self): def test_unset_value(self): with pytest.raises(ValueError): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": UNSET_VALUE}) with pytest.raises(ValueError): self.bound.bind((0, 0, 0, UNSET_VALUE)) @@ -192,13 +208,13 @@ def test_dict_missing_routing_key(self): # in v4 it implicitly binds UNSET_VALUE for missing items, # UNSET_VALUE is ValueError for routing keys with pytest.raises(ValueError): - self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk0": 0, "ck0": 0, "v0": 0}) with pytest.raises(ValueError): - self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk1": 0, "ck0": 0, "v0": 0}) def test_missing_value(self): # in v4 missing values are UNSET_VALUE - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0}) assert self.bound.values[-1] == UNSET_VALUE old_values = self.bound.values @@ -207,7 +223,7 @@ def test_missing_value(self): assert self.bound.values[-1] == UNSET_VALUE def test_unset_value(self): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": UNSET_VALUE}) assert self.bound.values[-1] == UNSET_VALUE self.bound.bind((0, 0, 0, UNSET_VALUE)) @@ -216,3 +232,286 @@ def test_unset_value(self): class BoundStatementTestV5(BoundStatementTestV4): protocol_version = 5 + + +class StubSerializer: + """Stub that mimics a Cython Serializer object for testing the fast path.""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + def serialize(self, value, protocol_version): + return self.cqltype.serialize(value, protocol_version) + + +class OverflowSerializer: + """Stub that raises struct.error, mimicking Cython int32 range check.""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + def serialize(self, value, protocol_version): + raise struct.error("'i' format requires -2147483648 <= number <= 2147483647") + + +class CythonBindPathTest(unittest.TestCase): + """Tests for the Cython serializer fast path in BoundStatement.bind(). + + These tests inject stub serializers via the PreparedStatement's cached + _cached_serializers attribute to exercise the Cython bind branch without + requiring compiled Cython. + """ + + protocol_version = 4 + + def _make_prepared(self, column_metadata, serializers=None): + """Create a PreparedStatement and inject serializers into its cache.""" + prepared = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[], + query=None, + keyspace="keyspace", + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) + # Inject directly into the cache attribute used by the _serializers + # property, bypassing the lazy initialization. + prepared._cached_serializers = serializers + return prepared + + def test_cython_path_normal_serialization(self): + """Cython fast path produces the same result as the plain Python path.""" + column_metadata = [ + ColumnMetadata("keyspace", "cf", "c0", Int32Type), + ColumnMetadata("keyspace", "cf", "c1", Int32Type), + ] + serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + bound.bind((42, -1)) + assert bound.values == [ + Int32Type.serialize(42, self.protocol_version), + Int32Type.serialize(-1, self.protocol_version), + ] + + def test_cython_path_none_value(self): + """None values pass through the Cython path without serialization.""" + column_metadata = [ColumnMetadata("keyspace", "cf", "c0", Int32Type)] + serializers = [StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + bound.bind((None,)) + assert bound.values == [None] + + def test_cython_path_unset_value(self): + """UNSET_VALUE is handled correctly in the Cython fast path (v4+).""" + column_metadata = [ + ColumnMetadata("keyspace", "cf", "c0", Int32Type), + ColumnMetadata("keyspace", "cf", "c1", Int32Type), + ] + serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + bound.bind((42, UNSET_VALUE)) + assert bound.values[0] == Int32Type.serialize(42, self.protocol_version) + assert bound.values[1] == UNSET_VALUE + + def test_cython_path_overflow_error_wrapped(self): + """struct.error from Cython int32 range check is caught and wrapped with column context.""" + column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] + serializers = [OverflowSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + with pytest.raises(TypeError) as exc: + bound.bind((2**31,)) + msg = str(exc.value) + assert "v0" in msg + assert "Int32Type" in msg + assert "int" in msg + + def test_cython_path_type_error_wrapped(self): + """TypeError from serializer is caught and wrapped with column context.""" + column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] + serializers = [StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + with pytest.raises(TypeError) as exc: + bound.bind(("not_an_int",)) + msg = str(exc.value) + assert "v0" in msg + assert "Int32Type" in msg + + def test_plain_path_overflow_error_wrapped(self): + """Out-of-range int in the plain Python path raises OverflowError (from + the Cython serializer) or struct.error (from the pure-Python fallback) + and is wrapped with column context.""" + column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] + # Force the plain Python path (no Cython serializers) + prepared = self._make_prepared(column_metadata, serializers=None) + + bound = BoundStatement(prepared_statement=prepared) + with pytest.raises(TypeError) as exc: + bound.bind((2**31,)) + msg = str(exc.value) + assert "v0" in msg + assert "Int32Type" in msg + + +class UnsetValueBindingTest(unittest.TestCase): + """Tests for UNSET_VALUE handling in all bind paths. + + These specifically test UNSET_VALUE in non-trailing positions to catch + index-management bugs in the pre-allocated values list. + """ + + protocol_version = 4 + + def _make_prepared( + self, column_metadata, serializers=None, routing_key_indexes=None + ): + prepared = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=routing_key_indexes or [], + query=None, + keyspace="keyspace", + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) + prepared._cached_serializers = serializers + return prepared + + def _three_column_metadata(self): + return [ + ColumnMetadata("keyspace", "cf", "c0", Int32Type), + ColumnMetadata("keyspace", "cf", "c1", Int32Type), + ColumnMetadata("keyspace", "cf", "c2", Int32Type), + ] + + # --- Plain Python path (no serializers) --- + + def test_plain_unset_mid_list(self): + """UNSET_VALUE in the middle of a value list does not corrupt indices.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0, UNSET_VALUE, 2)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] + + def test_plain_unset_first(self): + """UNSET_VALUE as the first value does not corrupt indices.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, 1, 2)) + assert bound.values == [ + UNSET_VALUE, + Int32Type.serialize(1, self.protocol_version), + Int32Type.serialize(2, self.protocol_version), + ] + + def test_plain_all_unset(self): + """All values are UNSET_VALUE.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, UNSET_VALUE, UNSET_VALUE)) + assert bound.values == [UNSET_VALUE, UNSET_VALUE, UNSET_VALUE] + + def test_plain_unset_trailing(self): + """UNSET_VALUE as the last explicit value.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0, 1, UNSET_VALUE)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + Int32Type.serialize(1, self.protocol_version), + UNSET_VALUE, + ] + + def test_plain_implicit_unset_fill(self): + """Fewer values than columns fills remaining with implicit UNSET_VALUE.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0,)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + UNSET_VALUE, + UNSET_VALUE, + ] + + def test_plain_mixed_none_and_unset(self): + """Mix of None and UNSET_VALUE in the same bind.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((None, UNSET_VALUE, 2)) + assert bound.values == [ + None, + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] + + # --- Cython serializer path (with stub serializers) --- + + def test_cython_unset_mid_list(self): + """UNSET_VALUE in the middle with Cython serializers does not corrupt indices.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0, UNSET_VALUE, 2)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] + + def test_cython_unset_first(self): + """UNSET_VALUE as the first value with Cython serializers.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, 1, 2)) + assert bound.values == [ + UNSET_VALUE, + Int32Type.serialize(1, self.protocol_version), + Int32Type.serialize(2, self.protocol_version), + ] + + def test_cython_all_unset(self): + """All values are UNSET_VALUE with Cython serializers.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, UNSET_VALUE, UNSET_VALUE)) + assert bound.values == [UNSET_VALUE, UNSET_VALUE, UNSET_VALUE] + + def test_cython_mixed_none_and_unset(self): + """Mix of None and UNSET_VALUE with Cython serializers.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((None, UNSET_VALUE, 2)) + assert bound.values == [ + None, + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ]