From e2dda9bbc7e2dcb109107e46954050e94a35b0e1 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sat, 14 Mar 2026 13:07:55 +0200 Subject: [PATCH 1/5] (improvement) serializers: add Cython-optimized serialization for VectorType Add cassandra/serializers.pyx and cassandra/serializers.pxd implementing Cython-optimized serialization that mirrors the deserializers.pyx architecture. Implements type-specialized serializers for the three subtypes commonly used in vector columns: - SerFloatType: 4-byte big-endian IEEE 754 float - SerDoubleType: 8-byte big-endian double - SerInt32Type: 4-byte big-endian signed int32 SerVectorType pre-allocates a contiguous buffer and uses C-level byte swapping for float/double/int32 vectors, with a generic fallback for other subtypes. GenericSerializer delegates to the Python-level cqltype.serialize() classmethod. Range checks for float32 and int32 values prevent silent truncation from C-level casts, matching the behavior of struct.pack(). Factory functions find_serializer() and make_serializers() allow easy lookup and batch creation of serializers for column types. Benchmarks show ~30x speedup over the current io.BytesIO baseline and ~3x speedup over Python struct.pack for Vector serialization. No setup.py changes needed - the existing cassandra/*.pyx glob already picks up new .pyx files. --- cassandra/serializers.pxd | 20 ++ cassandra/serializers.pyx | 389 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 409 insertions(+) create mode 100644 cassandra/serializers.pxd create mode 100644 cassandra/serializers.pyx diff --git a/cassandra/serializers.pxd b/cassandra/serializers.pxd new file mode 100644 index 0000000000..60297077a8 --- /dev/null +++ b/cassandra/serializers.pxd @@ -0,0 +1,20 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef class Serializer: + # The cqltypes._CassandraType corresponding to this serializer + cdef object cqltype + + cpdef bytes serialize(self, object value, int protocol_version) diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx new file mode 100644 index 0000000000..0bf5ddf80b --- /dev/null +++ b/cassandra/serializers.pyx @@ -0,0 +1,389 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cython-optimized serializers for CQL types. + +Mirrors the architecture of deserializers.pyx. Currently implements +optimized serialization for: +- FloatType (4-byte big-endian float) +- DoubleType (8-byte big-endian double) +- Int32Type (4-byte big-endian signed int) +- VectorType (type-specialized for float/double/int32, generic fallback) + +For all other types, GenericSerializer delegates to the Python-level +cqltype.serialize() classmethod. +""" + +from libc.stdint cimport int32_t +from libc.string cimport memcpy +from libc.stdlib cimport malloc, free +from libc.float cimport FLT_MAX +from libc.math cimport isinf, isnan +from cpython.bytes cimport PyBytes_FromStringAndSize + +from cassandra import cqltypes + +cdef bint is_little_endian +from cassandra.util import is_little_endian + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + +cdef class Serializer: + """Cython-based serializer class for a cqltype""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + cpdef bytes serialize(self, object value, int protocol_version): + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Float range check +# --------------------------------------------------------------------------- + +cdef inline void _check_float_range(double value) except *: + """Raise OverflowError for finite values outside float32 range. + + This matches the behavior of struct.pack('>f', value), which raises + OverflowError (via struct.error) for values that cannot be represented + as a 32-bit IEEE 754 float. inf, -inf, and nan pass through unchanged. + """ + if not isinf(value) and not isnan(value): + if value > FLT_MAX or value < -FLT_MAX: + raise OverflowError( + "Value %r too large for float32 (max %r)" % (value, FLT_MAX) + ) + + +# --------------------------------------------------------------------------- +# Int32 range check +# --------------------------------------------------------------------------- + +cdef inline void _check_int32_range(object value) except *: + """Raise OverflowError for values outside the signed int32 range. + + This matches the behavior of struct.pack('>i', value), which raises + struct.error for values outside [-2147483648, 2147483647]. The check + must be done on the Python int *before* the C-level cast, + which would silently truncate. + """ + if value > 2147483647 or value < -2147483648: + raise OverflowError( + "Value %r out of range for int32 " + "(must be between -2147483648 and 2147483647)" % (value,) + ) + + +# --------------------------------------------------------------------------- +# Scalar serializers +# --------------------------------------------------------------------------- + +cdef class SerFloatType(Serializer): + """Serialize a Python float to 4-byte big-endian IEEE 754.""" + + cpdef bytes serialize(self, object value, int protocol_version): + _check_float_range(value) + cdef float val = value + cdef char out[4] + cdef char *src = &val + + if is_little_endian: + out[0] = src[3] + out[1] = src[2] + out[2] = src[1] + out[3] = src[0] + else: + memcpy(out, src, 4) + + return PyBytes_FromStringAndSize(out, 4) + + +cdef class SerDoubleType(Serializer): + """Serialize a Python float to 8-byte big-endian IEEE 754.""" + + cpdef bytes serialize(self, object value, int protocol_version): + cdef double val = value + cdef char out[8] + cdef char *src = &val + + if is_little_endian: + out[0] = src[7] + out[1] = src[6] + out[2] = src[5] + out[3] = src[4] + out[4] = src[3] + out[5] = src[2] + out[6] = src[1] + out[7] = src[0] + else: + memcpy(out, src, 8) + + return PyBytes_FromStringAndSize(out, 8) + + +cdef class SerInt32Type(Serializer): + """Serialize a Python int to 4-byte big-endian signed int32.""" + + cpdef bytes serialize(self, object value, int protocol_version): + _check_int32_range(value) + cdef int32_t val = value + cdef char out[4] + cdef char *src = &val + + if is_little_endian: + out[0] = src[3] + out[1] = src[2] + out[2] = src[1] + out[3] = src[0] + else: + memcpy(out, src, 4) + + return PyBytes_FromStringAndSize(out, 4) + + +# --------------------------------------------------------------------------- +# Type detection helpers +# --------------------------------------------------------------------------- + +cdef inline bint _is_float_type(object subtype): + return subtype is cqltypes.FloatType or issubclass(subtype, cqltypes.FloatType) + +cdef inline bint _is_double_type(object subtype): + return subtype is cqltypes.DoubleType or issubclass(subtype, cqltypes.DoubleType) + +cdef inline bint _is_int32_type(object subtype): + return subtype is cqltypes.Int32Type or issubclass(subtype, cqltypes.Int32Type) + + +# --------------------------------------------------------------------------- +# VectorType serializer +# --------------------------------------------------------------------------- + +cdef class SerVectorType(Serializer): + """ + Optimized Cython serializer for VectorType. + + For float, double, and int32 vectors, pre-allocates a contiguous buffer + and uses C-level byte swapping. For other subtypes, falls back to + per-element Python serialization. + """ + + cdef int vector_size + cdef object subtype + # 0 = generic, 1 = float, 2 = double, 3 = int32 + cdef int type_code + + def __init__(self, cqltype): + super().__init__(cqltype) + self.vector_size = cqltype.vector_size + self.subtype = cqltype.subtype + + if _is_float_type(self.subtype): + self.type_code = 1 + elif _is_double_type(self.subtype): + self.type_code = 2 + elif _is_int32_type(self.subtype): + self.type_code = 3 + else: + self.type_code = 0 + + cpdef bytes serialize(self, object value, int protocol_version): + cdef int v_length = len(value) + if v_length != self.vector_size: + raise ValueError( + "Expected sequence of size %d for vector of type %s and " + "dimension %d, observed sequence of length %d" % ( + self.vector_size, self.subtype.typename, + self.vector_size, v_length)) + + if self.type_code == 1: + return self._serialize_float(value) + elif self.type_code == 2: + return self._serialize_double(value) + elif self.type_code == 3: + return self._serialize_int32(value) + else: + return self._serialize_generic(value, protocol_version) + + cdef inline bytes _serialize_float(self, object values): + """Serialize a list of floats into a contiguous big-endian buffer.""" + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 4 + if buf_size == 0: + return b"" + cdef char *buf = malloc(buf_size) + if buf == NULL: + raise MemoryError("Failed to allocate %d bytes for vector serialization" % buf_size) + + cdef float val + cdef char *src + cdef char *dst + + try: + for i in range(self.vector_size): + _check_float_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return PyBytes_FromStringAndSize(buf, buf_size) + finally: + free(buf) + + cdef inline bytes _serialize_double(self, object values): + """Serialize a list of doubles into a contiguous big-endian buffer.""" + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 8 + if buf_size == 0: + return b"" + cdef char *buf = malloc(buf_size) + if buf == NULL: + raise MemoryError("Failed to allocate %d bytes for vector serialization" % buf_size) + + cdef double val + cdef char *src + cdef char *dst + + try: + for i in range(self.vector_size): + val = values[i] + src = &val + dst = buf + i * 8 + + if is_little_endian: + dst[0] = src[7] + dst[1] = src[6] + dst[2] = src[5] + dst[3] = src[4] + dst[4] = src[3] + dst[5] = src[2] + dst[6] = src[1] + dst[7] = src[0] + else: + memcpy(dst, src, 8) + + return PyBytes_FromStringAndSize(buf, buf_size) + finally: + free(buf) + + cdef inline bytes _serialize_int32(self, object values): + """Serialize a list of int32 values into a contiguous big-endian buffer.""" + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 4 + if buf_size == 0: + return b"" + cdef char *buf = malloc(buf_size) + if buf == NULL: + raise MemoryError("Failed to allocate %d bytes for vector serialization" % buf_size) + + cdef int32_t val + cdef char *src + cdef char *dst + + try: + for i in range(self.vector_size): + _check_int32_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return PyBytes_FromStringAndSize(buf, buf_size) + finally: + free(buf) + + cdef inline bytes _serialize_generic(self, object values, int protocol_version): + """Fallback: element-by-element Python serialization for non-optimized types.""" + import io + from cassandra.marshal import uvint_pack + + serialized_size = self.subtype.serial_size() + buf = io.BytesIO() + for item in values: + item_bytes = self.subtype.serialize(item, protocol_version) + if serialized_size is None: + buf.write(uvint_pack(len(item_bytes))) + buf.write(item_bytes) + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# Generic serializer (fallback for all other types) +# --------------------------------------------------------------------------- + +cdef class GenericSerializer(Serializer): + """ + Wraps a generic cqltype for serialization, delegating to the Python-level + cqltype.serialize() classmethod. + """ + + cpdef bytes serialize(self, object value, int protocol_version): + return self.cqltype.serialize(value, protocol_version) + + def __repr__(self): + return "GenericSerializer(%s)" % (self.cqltype,) + + +# --------------------------------------------------------------------------- +# Lookup and factory +# --------------------------------------------------------------------------- + +cdef dict _ser_classes = {} + +cpdef Serializer find_serializer(cqltype): + """Find a serializer for a cqltype.""" + + # For VectorType, always use SerVectorType (it handles generic subtypes internally) + if issubclass(cqltype, cqltypes.VectorType): + return SerVectorType(cqltype) + + # For scalar types with dedicated serializers, look up by name + name = 'Ser' + cqltype.__name__ + cls = _ser_classes.get(name) + if cls is not None: + return cls(cqltype) + + # Fallback to generic + return GenericSerializer(cqltype) + + +def make_serializers(cqltypes_list): + """Create a list of Serializer objects for each given cqltype.""" + return [find_serializer(ct) for ct in cqltypes_list] + + +# Build the lookup dict for scalar serializers at module load time +_ser_classes['SerFloatType'] = SerFloatType +_ser_classes['SerDoubleType'] = SerDoubleType +_ser_classes['SerInt32Type'] = SerInt32Type From c6c0584281b59b57c568ce96e6281a015d7c63dc Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sat, 14 Mar 2026 13:21:29 +0200 Subject: [PATCH 2/5] (improvement) query: add Cython-aware serializer path in BoundStatement.bind() When Cython serializers (from cassandra.serializers) are available and no column encryption policy is active, BoundStatement.bind() now uses pre-built Serializer objects cached on the PreparedStatement instead of calling cqltype classmethods. This avoids per-value Python method dispatch overhead and enables the ~30x vector serialization speedup from the Cython serializers module. The bind loop is split into three paths: 1. Column encryption policy path (unchanged behavior) 2. Cython serializers path (new fast path) 3. Plain Python path (no CE, no Cython -- removes per-value ColDesc/CE check) Depends on PR #748 (Cython serializers module) and PR #630 (CE-policy bind split). --- cassandra/query.py | 115 ++++++++++++++++++++----- tests/unit/test_parameter_binding.py | 120 +++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 20 deletions(-) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..46f6073067 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -33,6 +33,12 @@ from cassandra.protocol import _UNSET_VALUE from cassandra.util import OrderedDict, _sanitize_identifiers +try: + from cassandra.serializers import make_serializers as _cython_make_serializers + _HAVE_CYTHON_SERIALIZERS = True +except ImportError: + _HAVE_CYTHON_SERIALIZERS = False + import logging log = logging.getLogger(__name__) @@ -474,6 +480,30 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self.is_idempotent = False self._is_lwt = is_lwt + @property + def _serializers(self): + """Lazily create and cache Cython serializers for column types. + + Returns a list of Serializer objects if Cython serializers are available + and there is no column encryption policy, otherwise returns None. + + The column_encryption_policy check is performed on every access (not + cached) so that serializers are correctly bypassed if a policy is set + after construction. + """ + if self.column_encryption_policy: + return None + try: + return self.__serializers + except AttributeError: + pass + if _HAVE_CYTHON_SERIALIZERS and self.column_metadata: + self.__serializers = _cython_make_serializers( + [col.type for col in self.column_metadata]) + else: + self.__serializers = None + return self.__serializers + @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version, result_metadata, @@ -532,6 +562,14 @@ def __str__(self): __repr__ = __str__ +def _raise_bind_serialize_error(col_spec, value, exc): + """Wrap serialization errors with column context for all bind loop paths.""" + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) + + class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. @@ -636,28 +674,65 @@ def bind(self, values): self.raw_values = values self.values = [] - for value, col_spec in zip(values, col_meta): - if value is None: - self.values.append(None) - elif value is UNSET_VALUE: - if proto_version >= 4: - self._append_unset_value() + if ce_policy: + # Column encryption path: check each column for CE policy + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + try: + col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + uses_ce = ce_policy.contains_column(col_desc) + if uses_ce: + col_type = ce_policy.column_type(col_desc) + col_bytes = col_type.serialize(value, proto_version) + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + else: + col_bytes = col_spec.type.serialize(value, proto_version) + self.values.append(col_bytes) + # OverflowError: Cython int32/float casts may raise on out-of-range values + except (TypeError, struct.error, OverflowError) as exc: + _raise_bind_serialize_error(col_spec, value, exc) + else: + # Fast path: no column encryption, use Cython serializers if available + serializers = self.prepared_statement._serializers + if serializers is not None: + for ser, value, col_spec in zip(serializers, values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + else: + try: + col_bytes = ser.serialize(value, proto_version) + self.values.append(col_bytes) + # OverflowError: Cython int32/float casts may raise on out-of-range values + except (TypeError, struct.error, OverflowError) as exc: + _raise_bind_serialize_error(col_spec, value, exc) else: - try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) - uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) - if uses_ce: - col_bytes = ce_policy.encrypt(col_desc, col_bytes) - self.values.append(col_bytes) - except (TypeError, struct.error) as exc: - actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) - raise TypeError(message) + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + else: + try: + col_bytes = col_spec.type.serialize(value, proto_version) + self.values.append(col_bytes) + # OverflowError: Cython int32/float casts may raise on out-of-range values + except (TypeError, struct.error, OverflowError) as exc: + _raise_bind_serialize_error(col_spec, value, exc) if proto_version >= 4: diff = col_meta_len - len(self.values) diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 5416ac461d..88b095d54f 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -216,3 +216,123 @@ def test_unset_value(self): class BoundStatementTestV5(BoundStatementTestV4): protocol_version = 5 + + +class StubSerializer: + """Stub that mimics a Cython Serializer object for testing the fast path.""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + def serialize(self, value, protocol_version): + return self.cqltype.serialize(value, protocol_version) + + +class OverflowSerializer: + """Stub that raises OverflowError, mimicking Cython cast overflow.""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + def serialize(self, value, protocol_version): + raise OverflowError('value too large to convert to int32_t') + + +class CythonBindPathTest(unittest.TestCase): + """Tests for the Cython serializer fast path in BoundStatement.bind(). + + These tests inject stub serializers via the PreparedStatement's cached + __serializers attribute to exercise the Cython bind branch without + requiring compiled Cython. + """ + + protocol_version = 4 + + def _make_prepared(self, column_metadata, serializers=None): + """Create a PreparedStatement and inject serializers into its cache.""" + prepared = PreparedStatement(column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[], + query=None, + keyspace='keyspace', + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None) + # Inject directly into the name-mangled cache attribute used by + # the _serializers property, bypassing the lazy initialization. + prepared._PreparedStatement__serializers = serializers + return prepared + + def test_cython_path_normal_serialization(self): + """Cython fast path produces the same result as the plain Python path.""" + column_metadata = [ColumnMetadata('keyspace', 'cf', 'c0', Int32Type), + ColumnMetadata('keyspace', 'cf', 'c1', Int32Type)] + serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + bound.bind((42, -1)) + assert bound.values == [Int32Type.serialize(42, self.protocol_version), + Int32Type.serialize(-1, self.protocol_version)] + + def test_cython_path_none_value(self): + """None values pass through the Cython path without serialization.""" + column_metadata = [ColumnMetadata('keyspace', 'cf', 'c0', Int32Type)] + serializers = [StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + bound.bind((None,)) + assert bound.values == [None] + + def test_cython_path_unset_value(self): + """UNSET_VALUE is handled correctly in the Cython fast path (v4+).""" + column_metadata = [ColumnMetadata('keyspace', 'cf', 'c0', Int32Type), + ColumnMetadata('keyspace', 'cf', 'c1', Int32Type)] + serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + bound.bind((42, UNSET_VALUE)) + assert bound.values[0] == Int32Type.serialize(42, self.protocol_version) + assert bound.values[1] == UNSET_VALUE + + def test_cython_path_overflow_error_wrapped(self): + """OverflowError from Cython cast is caught and wrapped with column context.""" + column_metadata = [ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + serializers = [OverflowSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + with pytest.raises(TypeError) as exc: + bound.bind((2**31,)) + msg = str(exc.value) + assert 'v0' in msg + assert 'Int32Type' in msg + assert 'int' in msg + + def test_cython_path_type_error_wrapped(self): + """TypeError from serializer is caught and wrapped with column context.""" + column_metadata = [ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + serializers = [StubSerializer(Int32Type)] + prepared = self._make_prepared(column_metadata, serializers) + + bound = BoundStatement(prepared_statement=prepared) + with pytest.raises(TypeError) as exc: + bound.bind(('not_an_int',)) + msg = str(exc.value) + assert 'v0' in msg + assert 'Int32Type' in msg + + def test_plain_path_overflow_error_wrapped(self): + """OverflowError in the plain Python path is also caught and wrapped.""" + column_metadata = [ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + # Force the plain Python path (no Cython serializers) + prepared = self._make_prepared(column_metadata, serializers=None) + + bound = BoundStatement(prepared_statement=prepared) + with pytest.raises(TypeError) as exc: + bound.bind((2**31,)) + msg = str(exc.value) + assert 'v0' in msg + assert 'Int32Type' in msg From 25433dda5b26001a27f38008c2da7cc9d2b86fa5 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 19 Mar 2026 12:46:39 +0200 Subject: [PATCH 3/5] (improvement) query: pre-allocate BoundStatement.values list to avoid repeated .append() growth --- cassandra/query.py | 521 +++++++++++++++++++-------- tests/unit/test_parameter_binding.py | 352 +++++++++++++----- 2 files changed, 642 insertions(+), 231 deletions(-) diff --git a/cassandra/query.py b/cassandra/query.py index 46f6073067..eec42c8de9 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -35,11 +35,13 @@ try: from cassandra.serializers import make_serializers as _cython_make_serializers + _HAVE_CYTHON_SERIALIZERS = True except ImportError: _HAVE_CYTHON_SERIALIZERS = False import logging + log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE @@ -55,9 +57,9 @@ Only valid when using native protocol v4+ """ -NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') -START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') -END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") _clean_name_cache = {} @@ -66,7 +68,9 @@ def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: - clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) _clean_name_cache[name] = clean return clean @@ -89,6 +93,7 @@ def tuple_factory(colnames, rows): """ return rows + class PseudoNamedTupleRow(object): """ Helper class for pseudo_named_tuple_factory. These objects provide an @@ -96,6 +101,7 @@ class PseudoNamedTupleRow(object): but otherwise do not attempt to implement the full namedtuple or iterable interface. """ + def __init__(self, ordered_dict): self._dict = ordered_dict self._tuple = tuple(ordered_dict.values()) @@ -110,8 +116,7 @@ def __iter__(self): return iter(self._tuple) def __repr__(self): - return '{t}({od})'.format(t=self.__class__.__name__, - od=self._dict) + return "{t}({od})".format(t=self.__class__.__name__, od=self._dict) def pseudo_namedtuple_factory(colnames, rows): @@ -119,8 +124,7 @@ def pseudo_namedtuple_factory(colnames, rows): Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback factory for cases where :meth:`.named_tuple_factory` fails to create rows. """ - return [PseudoNamedTupleRow(od) - for od in ordered_dict_factory(colnames, rows)] + return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)] def named_tuple_factory(colnames, rows): @@ -154,7 +158,7 @@ def named_tuple_factory(colnames, rows): """ clean_column_names = map(_clean_column_name, colnames) try: - Row = namedtuple('Row', clean_column_names) + Row = namedtuple("Row", clean_column_names) except SyntaxError: warnings.warn( "Failed creating namedtuple for a result because there were too " @@ -165,19 +169,23 @@ def named_tuple_factory(colnames, rows): "values on row objects, Upgrade to Python 3.7, or use a different " "row factory. (column names: {colnames})".format( substitute_factory_name=pseudo_namedtuple_factory.__name__, - colnames=colnames + colnames=colnames, ) ) return pseudo_namedtuple_factory(colnames, rows) except Exception: - clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt - log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " - "(see Python 'namedtuple' documentation for details on name rules). " - "Results will be returned with positional names. " - "Avoid this by choosing different names, using SELECT \"\" AS aliases, " - "or specifying a different row_factory on your Session" % - (colnames, clean_column_names)) - Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) + clean_column_names = list( + map(_clean_column_name, colnames) + ) # create list because py3 map object will be consumed by first attempt + log.warning( + "Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. " + 'Avoid this by choosing different names, using SELECT "" AS aliases, ' + "or specifying a different row_factory on your Session" + % (colnames, clean_column_names) + ) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] @@ -282,11 +290,24 @@ class Statement(object): _serial_consistency_level = None _routing_key = None - def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, - is_idempotent=False, table=None): - if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors - raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + def __init__( + self, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + table=None, + ): + if retry_policy and not hasattr( + retry_policy, "on_read_timeout" + ): # just checking one method to detect positional parameter errors + raise ValueError( + "retry_policy should implement cassandra.policies.RetryPolicy" + ) if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: @@ -335,17 +356,20 @@ def _del_routing_key(self): If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. - """) + """, + ) def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - if (serial_consistency_level is not None and - not ConsistencyLevel.is_serial(serial_consistency_level)): + if serial_consistency_level is not None and not ConsistencyLevel.is_serial( + serial_consistency_level + ): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL") + "or ConsistencyLevel.LOCAL_SERIAL" + ) self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): @@ -390,7 +414,8 @@ def is_lwt(self): conditional statements. .. versionadded:: 2.0.0 - """) + """, + ) class SimpleStatement(Statement): @@ -398,9 +423,18 @@ class SimpleStatement(Statement): A simple, un-prepared query. """ - def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None, is_idempotent=False): + def __init__( + self, + query_string, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + ): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the @@ -408,8 +442,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + is_idempotent, + ) self._query_string = query_string @property @@ -417,9 +460,14 @@ def query_string(self): return self._query_string def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -448,7 +496,7 @@ class PreparedStatement(object): A note about * in prepared statements """ - column_metadata = None #TODO: make this bind_metadata in next major + column_metadata = None # TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None @@ -465,9 +513,19 @@ class PreparedStatement(object): serial_consistency_level = None # TODO never used? _is_lwt = False - def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata, result_metadata_id, - is_lwt=False, column_encryption_policy=None): + def __init__( + self, + column_metadata, + query_id, + routing_key_indexes, + query, + keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt=False, + column_encryption_policy=None, + ): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -499,19 +557,40 @@ def _serializers(self): pass if _HAVE_CYTHON_SERIALIZERS and self.column_metadata: self.__serializers = _cython_make_serializers( - [col.type for col in self.column_metadata]) + [col.type for col in self.column_metadata] + ) else: self.__serializers = None return self.__serializers @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy=None): + def from_message( + cls, + query_id, + column_metadata, + pk_indexes, + cluster_metadata, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy=None, + ): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + return PreparedStatement( + column_metadata, + query_id, + None, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) if pk_indexes: routing_key_indexes = pk_indexes @@ -526,18 +605,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement - statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) + statement_indexes = dict( + (c.name, i) for i, c in enumerate(column_metadata) + ) # a list of which indexes in the statement correspond to partition key items try: - routing_key_indexes = [statement_indexes[c.name] - for c in partition_key_columns] - except KeyError: # we're missing a partition key component in the prepared - pass # statement; just leave routing_key_indexes as None - - return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + routing_key_indexes = [ + statement_indexes[c.name] for c in partition_key_columns + ] + except ( + KeyError + ): # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None + + return PreparedStatement( + column_metadata, + query_id, + routing_key_indexes, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) def bind(self, values): """ @@ -549,24 +642,33 @@ def bind(self, values): def is_routing_key_index(self, i): if self._routing_key_index_set is None: - self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() + self._routing_key_index_set = ( + set(self.routing_key_indexes) if self.routing_key_indexes else set() + ) return i in self._routing_key_index_set def is_lwt(self): return self._is_lwt def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ def _raise_bind_serialize_error(col_spec, value, exc): """Wrap serialization errors with column context for all bind loop paths.""" actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + message = ( + 'Received an argument of invalid type for column "%s". ' + "Expected: %s, Got: %s; (%s)" % (col_spec.name, col_spec.type, actual_type, exc) + ) raise TypeError(message) @@ -586,9 +688,17 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + def __init__( + self, + prepared_statement, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + ): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. @@ -609,9 +719,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None self.keyspace = meta[0].keyspace_name self.table = meta[0].table_name - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, - prepared_statement.is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + prepared_statement.is_idempotent, + ) def bind(self, values): """ @@ -653,40 +771,53 @@ def bind(self, values): values.append(UNSET_VALUE) else: raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + "Column name `%s` not found in bound dict." % (col.name) + ) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( - "Too many arguments provided to bind() (got %d, expected %d)" % - (len(values), len(col_meta))) + "Too many arguments provided to bind() (got %d, expected %d)" + % (len(values), len(col_meta)) + ) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. - if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ - value_len < len(self.prepared_statement.routing_key_indexes): + if ( + proto_version < 4 + and self.prepared_statement.routing_key_indexes + and value_len < len(self.prepared_statement.routing_key_indexes) + ): raise ValueError( - "Too few arguments provided to bind() (got %d, required %d for routing key)" % - (value_len, len(self.prepared_statement.routing_key_indexes))) + "Too few arguments provided to bind() (got %d, required %d for routing key)" + % (value_len, len(self.prepared_statement.routing_key_indexes)) + ) self.raw_values = values - self.values = [] + # Pre-allocate to avoid repeated list growth reallocations + self.values = [None] * col_meta_len + idx = 0 if ce_policy: # Column encryption path: check each column for CE policy for value, col_spec in zip(values, col_meta): if value is None: - self.values.append(None) + self.values[idx] = None elif value is UNSET_VALUE: if proto_version >= 4: - self._append_unset_value() + idx = self._append_unset_value(idx) + continue else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + col_desc = ColDesc( + col_spec.keyspace_name, col_spec.table_name, col_spec.name + ) uses_ce = ce_policy.contains_column(col_desc) if uses_ce: col_type = ce_policy.column_type(col_desc) @@ -694,60 +825,78 @@ def bind(self, values): col_bytes = ce_policy.encrypt(col_desc, col_bytes) else: col_bytes = col_spec.type.serialize(value, proto_version) - self.values.append(col_bytes) + self.values[idx] = col_bytes # OverflowError: Cython int32/float casts may raise on out-of-range values except (TypeError, struct.error, OverflowError) as exc: _raise_bind_serialize_error(col_spec, value, exc) + idx += 1 else: # Fast path: no column encryption, use Cython serializers if available serializers = self.prepared_statement._serializers if serializers is not None: for ser, value, col_spec in zip(serializers, values, col_meta): if value is None: - self.values.append(None) + self.values[idx] = None elif value is UNSET_VALUE: if proto_version >= 4: - self._append_unset_value() + idx = self._append_unset_value(idx) + continue else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: try: col_bytes = ser.serialize(value, proto_version) - self.values.append(col_bytes) + self.values[idx] = col_bytes # OverflowError: Cython int32/float casts may raise on out-of-range values except (TypeError, struct.error, OverflowError) as exc: _raise_bind_serialize_error(col_spec, value, exc) + idx += 1 else: for value, col_spec in zip(values, col_meta): if value is None: - self.values.append(None) + self.values[idx] = None elif value is UNSET_VALUE: if proto_version >= 4: - self._append_unset_value() + idx = self._append_unset_value(idx) + continue else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: try: col_bytes = col_spec.type.serialize(value, proto_version) - self.values.append(col_bytes) + self.values[idx] = col_bytes # OverflowError: Cython int32/float casts may raise on out-of-range values except (TypeError, struct.error, OverflowError) as exc: _raise_bind_serialize_error(col_spec, value, exc) + idx += 1 if proto_version >= 4: - diff = col_meta_len - len(self.values) - if diff: - for _ in range(diff): - self._append_unset_value() + # Fill remaining unbound columns with UNSET_VALUE (v4+ feature). + # The pre-allocated list already has slots for these, so index + # assignment works directly without trimming first. + for i in range(idx, col_meta_len): + idx = self._append_unset_value(idx) + elif idx < col_meta_len: + # Pre-v4: trim trailing unused slots (no UNSET_VALUE support) + self.values = self.values[:idx] return self - def _append_unset_value(self): - next_index = len(self.values) - if self.prepared_statement.is_routing_key_index(next_index): - col_meta = self.prepared_statement.column_metadata[next_index] - raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) - self.values.append(UNSET_VALUE) + def _append_unset_value(self, idx): + if self.prepared_statement.is_routing_key_index(idx): + col_meta = self.prepared_statement.column_metadata[idx] + raise ValueError( + "Cannot bind UNSET_VALUE as a part of the routing key '%s'" + % col_meta.name + ) + self.values[idx] = UNSET_VALUE + return idx + 1 @property def routing_key(self): @@ -761,7 +910,9 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) + self._routing_key = b"".join( + self._key_parts_packed(self.values[i] for i in routing_indexes) + ) return self._routing_key @@ -769,9 +920,15 @@ def is_lwt(self): return self.prepared_statement.is_lwt() def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.prepared_statement.query_string, self.raw_values, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.prepared_statement.query_string, + self.raw_values, + consistency, + ) + __repr__ = __str__ @@ -806,7 +963,7 @@ def __str__(self): return self.name def __repr__(self): - return "BatchType.%s" % (self.name, ) + return "BatchType.%s" % (self.name,) BatchType.LOGGED = BatchType("LOGGED", 0) @@ -838,9 +995,15 @@ class BatchStatement(Statement): _session = None _is_lwt = False - def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, - consistency_level=None, serial_consistency_level=None, - session=None, custom_payload=None): + def __init__( + self, + batch_type=BatchType.LOGGED, + retry_policy=None, + consistency_level=None, + serial_consistency_level=None, + session=None, + custom_payload=None, + ): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. @@ -888,8 +1051,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, self.batch_type = batch_type self._statements_and_parameters = [] self._session = session - Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, - serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + Statement.__init__( + self, + retry_policy=retry_policy, + consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level, + custom_payload=custom_payload, + ) def clear(self): """ @@ -928,11 +1096,14 @@ def add(self, statement, parameters=None): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " - "to BatchStatement.add()") + "to BatchStatement.add()" + ) self._update_state(statement) if statement.is_lwt(): self._is_lwt = True - self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) + self._add_statement_and_params( + True, statement.prepared_statement.query_id, statement.values + ) else: # it must be a SimpleStatement query_string = statement.query_string @@ -956,7 +1127,9 @@ def add_all(self, statements, parameters): def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: - raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + raise ValueError( + "Batch statement cannot contain more than %d statements." % 0xFFFF + ) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): @@ -982,9 +1155,15 @@ def __len__(self): return len(self._statements_and_parameters) def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.batch_type, len(self), consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return "" % ( + self.batch_type, + len(self), + consistency, + ) + __repr__ = __str__ @@ -1006,7 +1185,9 @@ def __str__(self): def bind_params(query, params, encoder): if isinstance(params, dict): - return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) + return query % dict( + (k, encoder.cql_encode_all_types(v)) for k, v in params.items() + ) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) @@ -1015,6 +1196,7 @@ class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ + pass @@ -1075,7 +1257,9 @@ class QueryTrace(object): _session = None - _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" + _SELECT_SESSIONS_FORMAT = ( + "SELECT * FROM system_traces.sessions WHERE session_id = %s" + ) _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 @@ -1104,18 +1288,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( - "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." + % (max_wait,) + ) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) - metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout + metadata_request_timeout = ( + self._session.cluster.control_connection + and self._session.cluster.control_connection._metadata_request_timeout + ) session_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_SESSIONS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), + (self.trace_id,), + time_spent, + max_wait, + ) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries session_row = session_results.one() if session_results else None - is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None + is_complete = ( + session_row is not None + and session_row.duration is not None + and session_row.started_at is not None + ) if not session_results or (wait_for_complete and not is_complete): - time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + time.sleep(self._BASE_RETRY_SLEEP * (2**attempt)) attempt += 1 continue if is_complete: @@ -1124,29 +1326,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) self.request_type = session_row.request - self.duration = timedelta(microseconds=session_row.duration) if is_complete else None + self.duration = ( + timedelta(microseconds=session_row.duration) if is_complete else None + ) self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 - self.client = getattr(session_row, 'client', None) + self.client = getattr(session_row, "client", None) - log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + log.debug( + "Attempting to fetch trace events for trace ID: %s", self.trace_id + ) time_spent = time.time() - start event_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout), - consistency_level=query_cl), + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_EVENTS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), (self.trace_id,), time_spent, - max_wait) + max_wait, + ) log.debug("Fetched trace events for trace ID: %s", self.trace_id) - self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) - for r in event_results) + self.events = tuple( + TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) + for r in event_results + ) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_response_future( + query, parameters, trace=False, custom_payload=None, timeout=timeout + ) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() @@ -1154,12 +1369,22 @@ def _execute(self, query, parameters, time_spent, max_wait): try: return future.result() except OperationTimedOut: - raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + raise TraceUnavailable( + "Trace information was not available within %f seconds" % (max_wait,) + ) def __str__(self): - return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ - % (self.request_type, self.trace_id, self.coordinator, self.started_at, - self.duration, self.parameters) + return ( + "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" + % ( + self.request_type, + self.trace_id, + self.coordinator, + self.started_at, + self.duration, + self.parameters, + ) + ) class TraceEvent(object): @@ -1196,7 +1421,9 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) + self.datetime = datetime.fromtimestamp( + unix_time_from_uuid1(timeuuid), tz=timezone.utc + ) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) @@ -1205,7 +1432,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.thread_name = thread_name def __str__(self): - return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + return "%s on %s[%s] at %s" % ( + self.description, + self.source, + self.thread_name, + self.datetime, + ) # TODO remove next major since we can target using the `host` attribute of session.execute @@ -1214,9 +1446,12 @@ class HostTargetingStatement(object): Wraps any query statement and attaches a target host, making it usable in a targeted LBP without modifying the user's statement. """ + def __init__(self, inner_statement, target_host): - self.__class__ = type(inner_statement.__class__.__name__, - (self.__class__, inner_statement.__class__), - {}) - self.__dict__ = inner_statement.__dict__ - self.target_host = target_host + self.__class__ = type( + inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}, + ) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 88b095d54f..8be29a9cd2 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -17,8 +17,13 @@ from cassandra.encoder import Encoder from cassandra.protocol import ColumnMetadata -from cassandra.query import (bind_params, ValueSequence, PreparedStatement, - BoundStatement, UNSET_VALUE) +from cassandra.query import ( + bind_params, + ValueSequence, + PreparedStatement, + BoundStatement, + UNSET_VALUE, +) from cassandra.cqltypes import Int32Type from cassandra.util import OrderedDict @@ -26,7 +31,6 @@ class ParamBindingTest(unittest.TestCase): - def test_bind_sequence(self): result = bind_params("%s %s %s", (1, "a", 2.0), Encoder()) assert result == "1 'a' 2.0" @@ -48,18 +52,18 @@ def test_none_param(self): assert result == "NULL" def test_list_collection(self): - result = bind_params("%s", (['a', 'b', 'c'],), Encoder()) + result = bind_params("%s", (["a", "b", "c"],), Encoder()) assert result == "['a', 'b', 'c']" def test_set_collection(self): - result = bind_params("%s", (set(['a', 'b']),), Encoder()) + result = bind_params("%s", (set(["a", "b"]),), Encoder()) assert result in ("{'a', 'b'}", "{'b', 'a'}") def test_map_collection(self): vals = OrderedDict() - vals['a'] = 'a' - vals['b'] = 'b' - vals['c'] = 'c' + vals["a"] = "a" + vals["b"] = "b" + vals["c"] = "c" result = bind_params("%s", (vals,), Encoder()) assert result == "{'a': 'a', 'b': 'b', 'c': 'c'}" @@ -68,63 +72,70 @@ def test_quote_escaping(self): assert result == """'''ef''''ef"ef""ef'''""" def test_float_precision(self): - f = 3.4028234663852886e+38 + f = 3.4028234663852886e38 assert float(bind_params("%s", (f,), Encoder())) == f -class BoundStatementTestV3(unittest.TestCase): +class BoundStatementTestV3(unittest.TestCase): protocol_version = 3 @classmethod def setUpClass(cls): - column_metadata = [ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type), - ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] - cls.prepared = PreparedStatement(column_metadata=column_metadata, - query_id=None, - routing_key_indexes=[1, 0], - query=None, - keyspace='keyspace', - protocol_version=cls.protocol_version, result_metadata=None, - result_metadata_id=None) + column_metadata = [ + ColumnMetadata("keyspace", "cf", "rk0", Int32Type), + ColumnMetadata("keyspace", "cf", "rk1", Int32Type), + ColumnMetadata("keyspace", "cf", "ck0", Int32Type), + ColumnMetadata("keyspace", "cf", "v0", Int32Type), + ] + cls.prepared = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[1, 0], + query=None, + keyspace="keyspace", + protocol_version=cls.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) cls.bound = BoundStatement(prepared_statement=cls.prepared) def test_invalid_argument_type(self): - values = (0, 0, 0, 'string not int') + values = (0, 0, 0, "string not int") with pytest.raises(TypeError) as exc: self.bound.bind(values) e = exc.value - assert 'v0' in str(e) - assert 'Int32Type' in str(e) - assert 'str' in str(e) + assert "v0" in str(e) + assert "Int32Type" in str(e) + assert "str" in str(e) - values = (['1', '2'], 0, 0, 0) + values = (["1", "2"], 0, 0, 0) with pytest.raises(TypeError) as exc: self.bound.bind(values) e = exc.value - assert 'rk0' in str(e) - assert 'Int32Type' in str(e) - assert 'list' in str(e) + assert "rk0" in str(e) + assert "Int32Type" in str(e) + assert "list" in str(e) def test_inherit_fetch_size(self): - keyspace = 'keyspace1' - column_family = 'cf1' + keyspace = "keyspace1" + column_family = "cf1" column_metadata = [ - ColumnMetadata(keyspace, column_family, 'foo1', Int32Type), - ColumnMetadata(keyspace, column_family, 'foo2', Int32Type) + ColumnMetadata(keyspace, column_family, "foo1", Int32Type), + ColumnMetadata(keyspace, column_family, "foo2", Int32Type), ] - prepared_statement = PreparedStatement(column_metadata=column_metadata, - query_id=None, - routing_key_indexes=[], - query=None, - keyspace=keyspace, - protocol_version=self.protocol_version, - result_metadata=None, - result_metadata_id=None) + prepared_statement = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[], + query=None, + keyspace=keyspace, + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) assert 1234 == bound_statement.fetch_size @@ -134,21 +145,23 @@ def test_too_few_parameters_for_routing_key(self): self.prepared.bind((1,)) bound = self.prepared.bind((1, 2)) - assert bound.keyspace == 'keyspace' + assert bound.keyspace == "keyspace" def test_dict_missing_routing_key(self): with pytest.raises(KeyError): - self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk0": 0, "ck0": 0, "v0": 0}) with pytest.raises(KeyError): - self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk1": 0, "ck0": 0, "v0": 0}) def test_missing_value(self): with pytest.raises(KeyError): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0}) def test_extra_value(self): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict - assert self.bound.values == [b'\x00' * 4] * 4 # four encoded zeros + self.bound.bind( + {"rk0": 0, "rk1": 0, "ck0": 0, "v0": 0, "should_not_be_here": 123} + ) # okay to have extra keys in dict + assert self.bound.values == [b"\x00" * 4] * 4 # four encoded zeros with pytest.raises(ValueError): self.bound.bind((0, 0, 0, 0, 123)) @@ -158,19 +171,21 @@ def test_values_none(self): self.bound.bind(None) # prepared statement with no values - prepared_statement = PreparedStatement(column_metadata=[], - query_id=None, - routing_key_indexes=[], - query=None, - keyspace='whatever', - protocol_version=self.protocol_version, - result_metadata=None, - result_metadata_id=None) + prepared_statement = PreparedStatement( + column_metadata=[], + query_id=None, + routing_key_indexes=[], + query=None, + keyspace="whatever", + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) bound = prepared_statement.bind(None) assertListEqual(bound.values, []) def test_bind_none(self): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": None}) assert self.bound.values[-1] == None old_values = self.bound.values @@ -180,7 +195,7 @@ def test_bind_none(self): def test_unset_value(self): with pytest.raises(ValueError): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": UNSET_VALUE}) with pytest.raises(ValueError): self.bound.bind((0, 0, 0, UNSET_VALUE)) @@ -192,13 +207,13 @@ def test_dict_missing_routing_key(self): # in v4 it implicitly binds UNSET_VALUE for missing items, # UNSET_VALUE is ValueError for routing keys with pytest.raises(ValueError): - self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk0": 0, "ck0": 0, "v0": 0}) with pytest.raises(ValueError): - self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0}) + self.bound.bind({"rk1": 0, "ck0": 0, "v0": 0}) def test_missing_value(self): # in v4 missing values are UNSET_VALUE - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0}) assert self.bound.values[-1] == UNSET_VALUE old_values = self.bound.values @@ -207,7 +222,7 @@ def test_missing_value(self): assert self.bound.values[-1] == UNSET_VALUE def test_unset_value(self): - self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + self.bound.bind({"rk0": 0, "rk1": 0, "ck0": 0, "v0": UNSET_VALUE}) assert self.bound.values[-1] == UNSET_VALUE self.bound.bind((0, 0, 0, UNSET_VALUE)) @@ -235,7 +250,7 @@ def __init__(self, cqltype): self.cqltype = cqltype def serialize(self, value, protocol_version): - raise OverflowError('value too large to convert to int32_t') + raise OverflowError("value too large to convert to int32_t") class CythonBindPathTest(unittest.TestCase): @@ -250,14 +265,16 @@ class CythonBindPathTest(unittest.TestCase): def _make_prepared(self, column_metadata, serializers=None): """Create a PreparedStatement and inject serializers into its cache.""" - prepared = PreparedStatement(column_metadata=column_metadata, - query_id=None, - routing_key_indexes=[], - query=None, - keyspace='keyspace', - protocol_version=self.protocol_version, - result_metadata=None, - result_metadata_id=None) + prepared = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[], + query=None, + keyspace="keyspace", + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) # Inject directly into the name-mangled cache attribute used by # the _serializers property, bypassing the lazy initialization. prepared._PreparedStatement__serializers = serializers @@ -265,19 +282,23 @@ def _make_prepared(self, column_metadata, serializers=None): def test_cython_path_normal_serialization(self): """Cython fast path produces the same result as the plain Python path.""" - column_metadata = [ColumnMetadata('keyspace', 'cf', 'c0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'c1', Int32Type)] + column_metadata = [ + ColumnMetadata("keyspace", "cf", "c0", Int32Type), + ColumnMetadata("keyspace", "cf", "c1", Int32Type), + ] serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)] prepared = self._make_prepared(column_metadata, serializers) bound = BoundStatement(prepared_statement=prepared) bound.bind((42, -1)) - assert bound.values == [Int32Type.serialize(42, self.protocol_version), - Int32Type.serialize(-1, self.protocol_version)] + assert bound.values == [ + Int32Type.serialize(42, self.protocol_version), + Int32Type.serialize(-1, self.protocol_version), + ] def test_cython_path_none_value(self): """None values pass through the Cython path without serialization.""" - column_metadata = [ColumnMetadata('keyspace', 'cf', 'c0', Int32Type)] + column_metadata = [ColumnMetadata("keyspace", "cf", "c0", Int32Type)] serializers = [StubSerializer(Int32Type)] prepared = self._make_prepared(column_metadata, serializers) @@ -287,8 +308,10 @@ def test_cython_path_none_value(self): def test_cython_path_unset_value(self): """UNSET_VALUE is handled correctly in the Cython fast path (v4+).""" - column_metadata = [ColumnMetadata('keyspace', 'cf', 'c0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'c1', Int32Type)] + column_metadata = [ + ColumnMetadata("keyspace", "cf", "c0", Int32Type), + ColumnMetadata("keyspace", "cf", "c1", Int32Type), + ] serializers = [StubSerializer(Int32Type), StubSerializer(Int32Type)] prepared = self._make_prepared(column_metadata, serializers) @@ -299,7 +322,7 @@ def test_cython_path_unset_value(self): def test_cython_path_overflow_error_wrapped(self): """OverflowError from Cython cast is caught and wrapped with column context.""" - column_metadata = [ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] serializers = [OverflowSerializer(Int32Type)] prepared = self._make_prepared(column_metadata, serializers) @@ -307,26 +330,26 @@ def test_cython_path_overflow_error_wrapped(self): with pytest.raises(TypeError) as exc: bound.bind((2**31,)) msg = str(exc.value) - assert 'v0' in msg - assert 'Int32Type' in msg - assert 'int' in msg + assert "v0" in msg + assert "Int32Type" in msg + assert "int" in msg def test_cython_path_type_error_wrapped(self): """TypeError from serializer is caught and wrapped with column context.""" - column_metadata = [ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] serializers = [StubSerializer(Int32Type)] prepared = self._make_prepared(column_metadata, serializers) bound = BoundStatement(prepared_statement=prepared) with pytest.raises(TypeError) as exc: - bound.bind(('not_an_int',)) + bound.bind(("not_an_int",)) msg = str(exc.value) - assert 'v0' in msg - assert 'Int32Type' in msg + assert "v0" in msg + assert "Int32Type" in msg def test_plain_path_overflow_error_wrapped(self): """OverflowError in the plain Python path is also caught and wrapped.""" - column_metadata = [ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] # Force the plain Python path (no Cython serializers) prepared = self._make_prepared(column_metadata, serializers=None) @@ -334,5 +357,158 @@ def test_plain_path_overflow_error_wrapped(self): with pytest.raises(TypeError) as exc: bound.bind((2**31,)) msg = str(exc.value) - assert 'v0' in msg - assert 'Int32Type' in msg + assert "v0" in msg + assert "Int32Type" in msg + + +class UnsetValueBindingTest(unittest.TestCase): + """Tests for UNSET_VALUE handling in all bind paths. + + These specifically test UNSET_VALUE in non-trailing positions to catch + index-management bugs in the pre-allocated values list. + """ + + protocol_version = 4 + + def _make_prepared( + self, column_metadata, serializers=None, routing_key_indexes=None + ): + prepared = PreparedStatement( + column_metadata=column_metadata, + query_id=None, + routing_key_indexes=routing_key_indexes or [], + query=None, + keyspace="keyspace", + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None, + ) + prepared._PreparedStatement__serializers = serializers + return prepared + + def _three_column_metadata(self): + return [ + ColumnMetadata("keyspace", "cf", "c0", Int32Type), + ColumnMetadata("keyspace", "cf", "c1", Int32Type), + ColumnMetadata("keyspace", "cf", "c2", Int32Type), + ] + + # --- Plain Python path (no serializers) --- + + def test_plain_unset_mid_list(self): + """UNSET_VALUE in the middle of a value list does not corrupt indices.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0, UNSET_VALUE, 2)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] + + def test_plain_unset_first(self): + """UNSET_VALUE as the first value does not corrupt indices.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, 1, 2)) + assert bound.values == [ + UNSET_VALUE, + Int32Type.serialize(1, self.protocol_version), + Int32Type.serialize(2, self.protocol_version), + ] + + def test_plain_all_unset(self): + """All values are UNSET_VALUE.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, UNSET_VALUE, UNSET_VALUE)) + assert bound.values == [UNSET_VALUE, UNSET_VALUE, UNSET_VALUE] + + def test_plain_unset_trailing(self): + """UNSET_VALUE as the last explicit value.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0, 1, UNSET_VALUE)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + Int32Type.serialize(1, self.protocol_version), + UNSET_VALUE, + ] + + def test_plain_implicit_unset_fill(self): + """Fewer values than columns fills remaining with implicit UNSET_VALUE.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0,)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + UNSET_VALUE, + UNSET_VALUE, + ] + + def test_plain_mixed_none_and_unset(self): + """Mix of None and UNSET_VALUE in the same bind.""" + col_meta = self._three_column_metadata() + prepared = self._make_prepared(col_meta, serializers=None) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((None, UNSET_VALUE, 2)) + assert bound.values == [ + None, + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] + + # --- Cython serializer path (with stub serializers) --- + + def test_cython_unset_mid_list(self): + """UNSET_VALUE in the middle with Cython serializers does not corrupt indices.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((0, UNSET_VALUE, 2)) + assert bound.values == [ + Int32Type.serialize(0, self.protocol_version), + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] + + def test_cython_unset_first(self): + """UNSET_VALUE as the first value with Cython serializers.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, 1, 2)) + assert bound.values == [ + UNSET_VALUE, + Int32Type.serialize(1, self.protocol_version), + Int32Type.serialize(2, self.protocol_version), + ] + + def test_cython_all_unset(self): + """All values are UNSET_VALUE with Cython serializers.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((UNSET_VALUE, UNSET_VALUE, UNSET_VALUE)) + assert bound.values == [UNSET_VALUE, UNSET_VALUE, UNSET_VALUE] + + def test_cython_mixed_none_and_unset(self): + """Mix of None and UNSET_VALUE with Cython serializers.""" + col_meta = self._three_column_metadata() + serializers = [StubSerializer(Int32Type)] * 3 + prepared = self._make_prepared(col_meta, serializers=serializers) + bound = BoundStatement(prepared_statement=prepared) + bound.bind((None, UNSET_VALUE, 2)) + assert bound.values == [ + None, + UNSET_VALUE, + Int32Type.serialize(2, self.protocol_version), + ] From 5d684d1fd9345c4cfbc7327bdafbcdfb3c30f379 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 19 Mar 2026 16:48:55 +0200 Subject: [PATCH 4/5] (improvement) serializers: address PR review feedback - Chain original exception in _raise_bind_serialize_error (raise from exc) - Change _check_int32_range to raise struct.error instead of OverflowError, matching the behaviour of struct.pack('>i', value) - Clarify docstrings for _check_float_range/_check_int32_range - Expand _raise_bind_serialize_error docstring with specific exception types - Document __getitem__ requirement in vector serialize methods - Move io and uvint_pack imports to module scope in serializers.pyx - Add struct import to serializers.pyx for struct.error - Fix test_plain_path_overflow_error_wrapped docstring (struct.error, not OverflowError) - Update OverflowSerializer stub to raise struct.error - Replace name-mangled __serializers with _cached_serializers --- cassandra/query.py | 26 ++++++++++------ cassandra/serializers.pyx | 45 +++++++++++++++++----------- tests/unit/test_parameter_binding.py | 20 +++++++------ 3 files changed, 56 insertions(+), 35 deletions(-) diff --git a/cassandra/query.py b/cassandra/query.py index eec42c8de9..645fd44612 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -552,16 +552,16 @@ def _serializers(self): if self.column_encryption_policy: return None try: - return self.__serializers + return self._cached_serializers except AttributeError: pass if _HAVE_CYTHON_SERIALIZERS and self.column_metadata: - self.__serializers = _cython_make_serializers( + self._cached_serializers = _cython_make_serializers( [col.type for col in self.column_metadata] ) else: - self.__serializers = None - return self.__serializers + self._cached_serializers = None + return self._cached_serializers @classmethod def from_message( @@ -663,13 +663,21 @@ def __str__(self): def _raise_bind_serialize_error(col_spec, value, exc): - """Wrap serialization errors with column context for all bind loop paths.""" + """Wrap TypeError, struct.error, or OverflowError with column context. + + Called from all three bind loop paths (CE, Cython, plain Python) to + provide a uniform error message that includes the column name and + expected type. struct.error arises from int32 out-of-range values; + OverflowError from float out-of-range values. Other exception types + (e.g. ValueError from VectorType dimension mismatch) propagate + without wrapping. + """ actual_type = type(value) message = ( 'Received an argument of invalid type for column "%s". ' "Expected: %s, Got: %s; (%s)" % (col_spec.name, col_spec.type, actual_type, exc) ) - raise TypeError(message) + raise TypeError(message) from exc class BoundStatement(Statement): @@ -826,7 +834,7 @@ def bind(self, values): else: col_bytes = col_spec.type.serialize(value, proto_version) self.values[idx] = col_bytes - # OverflowError: Cython int32/float casts may raise on out-of-range values + # struct.error: int32 out-of-range; OverflowError: float out-of-range except (TypeError, struct.error, OverflowError) as exc: _raise_bind_serialize_error(col_spec, value, exc) idx += 1 @@ -850,7 +858,7 @@ def bind(self, values): try: col_bytes = ser.serialize(value, proto_version) self.values[idx] = col_bytes - # OverflowError: Cython int32/float casts may raise on out-of-range values + # struct.error: int32 out-of-range; OverflowError: float out-of-range except (TypeError, struct.error, OverflowError) as exc: _raise_bind_serialize_error(col_spec, value, exc) idx += 1 @@ -871,7 +879,7 @@ def bind(self, values): try: col_bytes = col_spec.type.serialize(value, proto_version) self.values[idx] = col_bytes - # OverflowError: Cython int32/float casts may raise on out-of-range values + # struct.error: int32 out-of-range; OverflowError: float out-of-range except (TypeError, struct.error, OverflowError) as exc: _raise_bind_serialize_error(col_spec, value, exc) idx += 1 diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx index 0bf5ddf80b..6650487481 100644 --- a/cassandra/serializers.pyx +++ b/cassandra/serializers.pyx @@ -34,6 +34,9 @@ from libc.math cimport isinf, isnan from cpython.bytes cimport PyBytes_FromStringAndSize from cassandra import cqltypes +import io +import struct +from cassandra.marshal import uvint_pack cdef bint is_little_endian from cassandra.util import is_little_endian @@ -60,9 +63,9 @@ cdef class Serializer: cdef inline void _check_float_range(double value) except *: """Raise OverflowError for finite values outside float32 range. - This matches the behavior of struct.pack('>f', value), which raises - OverflowError (via struct.error) for values that cannot be represented - as a 32-bit IEEE 754 float. inf, -inf, and nan pass through unchanged. + Matches the behaviour of struct.pack('>f', value), which raises + OverflowError for values that cannot be represented as a 32-bit + IEEE 754 float. inf, -inf, and nan pass through unchanged. """ if not isinf(value) and not isnan(value): if value > FLT_MAX or value < -FLT_MAX: @@ -76,17 +79,16 @@ cdef inline void _check_float_range(double value) except *: # --------------------------------------------------------------------------- cdef inline void _check_int32_range(object value) except *: - """Raise OverflowError for values outside the signed int32 range. + """Raise struct.error for values outside the signed int32 range. - This matches the behavior of struct.pack('>i', value), which raises - struct.error for values outside [-2147483648, 2147483647]. The check - must be done on the Python int *before* the C-level cast, - which would silently truncate. + Matches the behaviour of struct.pack('>i', value), which raises + struct.error for out-of-range values. The check must be done on the + Python int *before* the C-level cast, which would silently + truncate. """ if value > 2147483647 or value < -2147483648: - raise OverflowError( - "Value %r out of range for int32 " - "(must be between -2147483648 and 2147483647)" % (value,) + raise struct.error( + "'i' format requires -2147483648 <= number <= 2147483647" ) @@ -222,7 +224,11 @@ cdef class SerVectorType(Serializer): return self._serialize_generic(value, protocol_version) cdef inline bytes _serialize_float(self, object values): - """Serialize a list of floats into a contiguous big-endian buffer.""" + """Serialize a sequence of floats into a contiguous big-endian buffer. + + Note: uses index-based access (values[i]) rather than iteration, so + the input must support __getitem__ (e.g. list or tuple). + """ cdef Py_ssize_t i cdef Py_ssize_t buf_size = self.vector_size * 4 if buf_size == 0: @@ -255,7 +261,11 @@ cdef class SerVectorType(Serializer): free(buf) cdef inline bytes _serialize_double(self, object values): - """Serialize a list of doubles into a contiguous big-endian buffer.""" + """Serialize a sequence of doubles into a contiguous big-endian buffer. + + Note: uses index-based access (values[i]) rather than iteration, so + the input must support __getitem__ (e.g. list or tuple). + """ cdef Py_ssize_t i cdef Py_ssize_t buf_size = self.vector_size * 8 if buf_size == 0: @@ -291,7 +301,11 @@ cdef class SerVectorType(Serializer): free(buf) cdef inline bytes _serialize_int32(self, object values): - """Serialize a list of int32 values into a contiguous big-endian buffer.""" + """Serialize a sequence of int32 values into a contiguous big-endian buffer. + + Note: uses index-based access (values[i]) rather than iteration, so + the input must support __getitem__ (e.g. list or tuple). + """ cdef Py_ssize_t i cdef Py_ssize_t buf_size = self.vector_size * 4 if buf_size == 0: @@ -325,9 +339,6 @@ cdef class SerVectorType(Serializer): cdef inline bytes _serialize_generic(self, object values, int protocol_version): """Fallback: element-by-element Python serialization for non-optimized types.""" - import io - from cassandra.marshal import uvint_pack - serialized_size = self.subtype.serial_size() buf = io.BytesIO() for item in values: diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 8be29a9cd2..834e1701f1 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import struct import pytest from cassandra.encoder import Encoder @@ -244,20 +245,20 @@ def serialize(self, value, protocol_version): class OverflowSerializer: - """Stub that raises OverflowError, mimicking Cython cast overflow.""" + """Stub that raises struct.error, mimicking Cython int32 range check.""" def __init__(self, cqltype): self.cqltype = cqltype def serialize(self, value, protocol_version): - raise OverflowError("value too large to convert to int32_t") + raise struct.error("'i' format requires -2147483648 <= number <= 2147483647") class CythonBindPathTest(unittest.TestCase): """Tests for the Cython serializer fast path in BoundStatement.bind(). These tests inject stub serializers via the PreparedStatement's cached - __serializers attribute to exercise the Cython bind branch without + _cached_serializers attribute to exercise the Cython bind branch without requiring compiled Cython. """ @@ -275,9 +276,9 @@ def _make_prepared(self, column_metadata, serializers=None): result_metadata=None, result_metadata_id=None, ) - # Inject directly into the name-mangled cache attribute used by - # the _serializers property, bypassing the lazy initialization. - prepared._PreparedStatement__serializers = serializers + # Inject directly into the cache attribute used by the _serializers + # property, bypassing the lazy initialization. + prepared._cached_serializers = serializers return prepared def test_cython_path_normal_serialization(self): @@ -321,7 +322,7 @@ def test_cython_path_unset_value(self): assert bound.values[1] == UNSET_VALUE def test_cython_path_overflow_error_wrapped(self): - """OverflowError from Cython cast is caught and wrapped with column context.""" + """struct.error from Cython int32 range check is caught and wrapped with column context.""" column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] serializers = [OverflowSerializer(Int32Type)] prepared = self._make_prepared(column_metadata, serializers) @@ -348,7 +349,8 @@ def test_cython_path_type_error_wrapped(self): assert "Int32Type" in msg def test_plain_path_overflow_error_wrapped(self): - """OverflowError in the plain Python path is also caught and wrapped.""" + """Out-of-range int in the plain Python path raises struct.error (caught + alongside OverflowError) and is wrapped with column context.""" column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] # Force the plain Python path (no Cython serializers) prepared = self._make_prepared(column_metadata, serializers=None) @@ -383,7 +385,7 @@ def _make_prepared( result_metadata=None, result_metadata_id=None, ) - prepared._PreparedStatement__serializers = serializers + prepared._cached_serializers = serializers return prepared def _three_column_metadata(self): From 9aabe6b8f7cace44f505e710c0e81eb8da3b386c Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 20 Mar 2026 19:08:19 +0200 Subject: [PATCH 5/5] (improvement) query, serializers: address PR review feedback - Replace malloc/free with PyBytes_FromStringAndSize(NULL, ...) pattern in vector fast-paths, eliminating extra buffer copy and malloc(0) edge case - Change _check_int32_range to raise OverflowError instead of struct.error, consistent with _check_float_range - Remove unused imports (struct, malloc, free) - Differentiate error messages in _raise_bind_serialize_error: 'invalid type' for TypeError vs 'value out of range' for OverflowError/struct.error - Replace unused loop variable with while loop in UNSET_VALUE fill - Expand __getitem__ docstrings explaining performance rationale - Fix copyright header in test_parameter_binding.py --- cassandra/query.py | 28 +++-- cassandra/serializers.pyx | 160 +++++++++++++-------------- tests/unit/test_parameter_binding.py | 7 +- 3 files changed, 103 insertions(+), 92 deletions(-) diff --git a/cassandra/query.py b/cassandra/query.py index 645fd44612..af9ceff0eb 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -547,7 +547,9 @@ def _serializers(self): The column_encryption_policy check is performed on every access (not cached) so that serializers are correctly bypassed if a policy is set - after construction. + after construction. This means the cache never goes stale: once a CE + policy is present, we always return None and fall through to the + encryption-aware bind path. """ if self.column_encryption_policy: return None @@ -667,15 +669,25 @@ def _raise_bind_serialize_error(col_spec, value, exc): Called from all three bind loop paths (CE, Cython, plain Python) to provide a uniform error message that includes the column name and - expected type. struct.error arises from int32 out-of-range values; - OverflowError from float out-of-range values. Other exception types - (e.g. ValueError from VectorType dimension mismatch) propagate - without wrapping. + expected type. The message distinguishes between wrong-type and + out-of-range scenarios: + + - TypeError -> "invalid type" + - OverflowError -> "value out of range" + - struct.error -> "value out of range" + + Other exception types (e.g. ValueError from VectorType dimension + mismatch) propagate without wrapping. """ actual_type = type(value) + if isinstance(exc, (OverflowError, struct.error)): + reason = "value out of range" + else: + reason = "invalid type" message = ( - 'Received an argument of invalid type for column "%s". ' - "Expected: %s, Got: %s; (%s)" % (col_spec.name, col_spec.type, actual_type, exc) + 'Received an argument with %s for column "%s". ' + "Expected: %s, Got: %s; (%s)" + % (reason, col_spec.name, col_spec.type, actual_type, exc) ) raise TypeError(message) from exc @@ -888,7 +900,7 @@ def bind(self, values): # Fill remaining unbound columns with UNSET_VALUE (v4+ feature). # The pre-allocated list already has slots for these, so index # assignment works directly without trimming first. - for i in range(idx, col_meta_len): + while idx < col_meta_len: idx = self._append_unset_value(idx) elif idx < col_meta_len: # Pre-v4: trim trailing unused slots (no UNSET_VALUE support) diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx index 6650487481..b02d0861cc 100644 --- a/cassandra/serializers.pyx +++ b/cassandra/serializers.pyx @@ -28,14 +28,12 @@ cqltype.serialize() classmethod. from libc.stdint cimport int32_t from libc.string cimport memcpy -from libc.stdlib cimport malloc, free from libc.float cimport FLT_MAX from libc.math cimport isinf, isnan -from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING from cassandra import cqltypes import io -import struct from cassandra.marshal import uvint_pack cdef bint is_little_endian @@ -79,15 +77,15 @@ cdef inline void _check_float_range(double value) except *: # --------------------------------------------------------------------------- cdef inline void _check_int32_range(object value) except *: - """Raise struct.error for values outside the signed int32 range. + """Raise OverflowError for values outside the signed int32 range. - Matches the behaviour of struct.pack('>i', value), which raises - struct.error for out-of-range values. The check must be done on the - Python int *before* the C-level cast, which would silently - truncate. + Mirrors ``_check_float_range``: we intentionally raise OverflowError + (not struct.error) so callers only need to catch one exception type + for out-of-range values. The check must be done on the Python int + *before* the C-level cast, which would silently truncate. """ if value > 2147483647 or value < -2147483648: - raise struct.error( + raise OverflowError( "'i' format requires -2147483648 <= number <= 2147483647" ) @@ -226,116 +224,116 @@ cdef class SerVectorType(Serializer): cdef inline bytes _serialize_float(self, object values): """Serialize a sequence of floats into a contiguous big-endian buffer. - Note: uses index-based access (values[i]) rather than iteration, so - the input must support __getitem__ (e.g. list or tuple). + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. """ cdef Py_ssize_t i cdef Py_ssize_t buf_size = self.vector_size * 4 if buf_size == 0: return b"" - cdef char *buf = malloc(buf_size) - if buf == NULL: - raise MemoryError("Failed to allocate %d bytes for vector serialization" % buf_size) + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) cdef float val cdef char *src cdef char *dst - try: - for i in range(self.vector_size): - _check_float_range(values[i]) - val = values[i] - src = &val - dst = buf + i * 4 - - if is_little_endian: - dst[0] = src[3] - dst[1] = src[2] - dst[2] = src[1] - dst[3] = src[0] - else: - memcpy(dst, src, 4) - - return PyBytes_FromStringAndSize(buf, buf_size) - finally: - free(buf) + for i in range(self.vector_size): + _check_float_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return result cdef inline bytes _serialize_double(self, object values): """Serialize a sequence of doubles into a contiguous big-endian buffer. - Note: uses index-based access (values[i]) rather than iteration, so - the input must support __getitem__ (e.g. list or tuple). + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. """ cdef Py_ssize_t i cdef Py_ssize_t buf_size = self.vector_size * 8 if buf_size == 0: return b"" - cdef char *buf = malloc(buf_size) - if buf == NULL: - raise MemoryError("Failed to allocate %d bytes for vector serialization" % buf_size) + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) cdef double val cdef char *src cdef char *dst - try: - for i in range(self.vector_size): - val = values[i] - src = &val - dst = buf + i * 8 - - if is_little_endian: - dst[0] = src[7] - dst[1] = src[6] - dst[2] = src[5] - dst[3] = src[4] - dst[4] = src[3] - dst[5] = src[2] - dst[6] = src[1] - dst[7] = src[0] - else: - memcpy(dst, src, 8) - - return PyBytes_FromStringAndSize(buf, buf_size) - finally: - free(buf) + for i in range(self.vector_size): + val = values[i] + src = &val + dst = buf + i * 8 + + if is_little_endian: + dst[0] = src[7] + dst[1] = src[6] + dst[2] = src[5] + dst[3] = src[4] + dst[4] = src[3] + dst[5] = src[2] + dst[6] = src[1] + dst[7] = src[0] + else: + memcpy(dst, src, 8) + + return result cdef inline bytes _serialize_int32(self, object values): """Serialize a sequence of int32 values into a contiguous big-endian buffer. - Note: uses index-based access (values[i]) rather than iteration, so - the input must support __getitem__ (e.g. list or tuple). + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. """ cdef Py_ssize_t i cdef Py_ssize_t buf_size = self.vector_size * 4 if buf_size == 0: return b"" - cdef char *buf = malloc(buf_size) - if buf == NULL: - raise MemoryError("Failed to allocate %d bytes for vector serialization" % buf_size) + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) cdef int32_t val cdef char *src cdef char *dst - try: - for i in range(self.vector_size): - _check_int32_range(values[i]) - val = values[i] - src = &val - dst = buf + i * 4 - - if is_little_endian: - dst[0] = src[3] - dst[1] = src[2] - dst[2] = src[1] - dst[3] = src[0] - else: - memcpy(dst, src, 4) - - return PyBytes_FromStringAndSize(buf, buf_size) - finally: - free(buf) + for i in range(self.vector_size): + _check_int32_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return result cdef inline bytes _serialize_generic(self, object values, int protocol_version): """Fallback: element-by-element Python serialization for non-optimized types.""" diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 834e1701f1..59c238169b 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -1,4 +1,4 @@ -# Copyright DataStax, Inc. +# Copyright ScyllaDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -349,8 +349,9 @@ def test_cython_path_type_error_wrapped(self): assert "Int32Type" in msg def test_plain_path_overflow_error_wrapped(self): - """Out-of-range int in the plain Python path raises struct.error (caught - alongside OverflowError) and is wrapped with column context.""" + """Out-of-range int in the plain Python path raises OverflowError (from + the Cython serializer) or struct.error (from the pure-Python fallback) + and is wrapped with column context.""" column_metadata = [ColumnMetadata("keyspace", "cf", "v0", Int32Type)] # Force the plain Python path (no Cython serializers) prepared = self._make_prepared(column_metadata, serializers=None)