diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 547a13c979..edd2eded9b 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -52,6 +52,7 @@ from cassandra import util _little_endian_flag = 1 # we always serialize LE +_INT32_NULL = int32_pack(-1) # pre-allocated null sentinel for collection serialization import ipaddress apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' @@ -841,7 +842,7 @@ def serialize_safe(cls, items, protocol_version): inner_proto = max(3, protocol_version) for item in items: if item is None: - buf.write(int32_pack(-1)) + buf.write(_INT32_NULL) else: itembytes = subtype.to_binary(item, inner_proto) buf.write(int32_pack(len(itembytes))) @@ -912,13 +913,13 @@ def serialize_safe(cls, themap, protocol_version): buf.write(int32_pack(len(keybytes))) buf.write(keybytes) else: - buf.write(int32_pack(-1)) + buf.write(_INT32_NULL) if val is not None: valbytes = value_type.to_binary(val, inner_proto) buf.write(int32_pack(len(valbytes))) buf.write(valbytes) else: - buf.write(int32_pack(-1)) + buf.write(_INT32_NULL) return buf.getvalue() @@ -964,7 +965,7 @@ def serialize_safe(cls, val, protocol_version): buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: - buf.write(int32_pack(-1)) + buf.write(_INT32_NULL) return buf.getvalue() @classmethod @@ -1041,7 +1042,7 @@ def serialize_safe(cls, val, protocol_version): buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: - buf.write(int32_pack(-1)) + buf.write(_INT32_NULL) return buf.getvalue() @classmethod diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 11aab2748d..b59d58e5f7 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -26,7 +26,7 @@ EmptyValue, LongType, SetType, UTF8Type, cql_typename, int8_pack, int64_pack, int64_unpack, lookup_casstype, lookup_casstype_simple, parse_casstype_args, - int32_pack, Int32Type, ListType, MapType, VectorType, + int32_pack, Int32Type, ListType, MapType, TupleType, UserType, VectorType, FloatType ) from cassandra.encoder import cql_quote @@ -1117,3 +1117,54 @@ def test_token_order(self): tokens_equal = [Token(1), Token(1)] check_sequence_consistency(tokens) check_sequence_consistency(tokens_equal, equal=True) + + +class CollectionNullSentinelTests(unittest.TestCase): + """ + Tests that collection types correctly round-trip None/null elements + using the pre-allocated _INT32_NULL sentinel. + """ + + def test_list_with_none_roundtrip(self): + proto = 4 + parameterized = ListType.apply_parameters([Int32Type]) + original = [1, None, 3] + serialized = parameterized.serialize(original, proto) + deserialized = parameterized.deserialize(serialized, proto) + self.assertEqual(deserialized, original) + + def test_set_with_none_roundtrip(self): + proto = 4 + parameterized = SetType.apply_parameters([Int32Type]) + original = [1, None, 3] + serialized = parameterized.serialize(original, proto) + deserialized = parameterized.deserialize(serialized, proto) + self.assertEqual(set(deserialized), set(original)) + + def test_map_with_none_values_roundtrip(self): + proto = 4 + parameterized = MapType.apply_parameters([Int32Type, Int32Type]) + original = {1: None, 2: 10} + serialized = parameterized.serialize(original, proto) + deserialized = parameterized.deserialize(serialized, proto) + self.assertEqual(dict(deserialized.items()), original) + + def test_tuple_with_none_roundtrip(self): + proto = 4 + parameterized = TupleType.apply_parameters([Int32Type, UTF8Type, Int32Type]) + original = (1, None, 3) + serialized = parameterized.serialize(original, proto) + deserialized = parameterized.deserialize(serialized, proto) + self.assertEqual(deserialized, original) + + def test_usertype_with_none_roundtrip(self): + proto = 4 + udt_class = UserType.make_udt_class( + 'test_ks', 'test_udt', + ('field_a', 'field_b', 'field_c'), + (Int32Type, UTF8Type, Int32Type) + ) + original = (1, None, 3) + serialized = udt_class.serialize(original, proto) + deserialized = udt_class.deserialize(serialized, proto) + self.assertEqual(deserialized, original)