diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8da9df6a55..75aeaa9435 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2998,9 +2998,7 @@ def _create_response_future(self, query, parameters, trace, custom_payload, message = ExecuteMessage( prepared_statement.query_id, query.values, cl, serial_cl, fetch_size, paging_state, timestamp, - skip_meta=bool(prepared_statement.result_metadata), - continuous_paging_options=continuous_paging_options, - result_metadata_id=prepared_statement.result_metadata_id) + continuous_paging_options=continuous_paging_options) elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( @@ -4618,6 +4616,15 @@ def _query(self, host, message=None, cb=None): self._connection = connection result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + if self.prepared_statement and isinstance(message, ExecuteMessage): + has_result_metadata_id = self.prepared_statement.result_metadata_id is not None + use_metadata_id = has_result_metadata_id and ( + ProtocolVersion.uses_prepared_metadata(connection.protocol_version) + or connection.features.use_metadata_id + ) + message.skip_meta = use_metadata_id + message.result_metadata_id = self.prepared_statement.result_metadata_id if use_metadata_id else None + if cb is None: cb = partial(self._set_result, host, connection, pool) @@ -4774,6 +4781,11 @@ def _set_result(self, host, connection, pool, response): self._paging_state = response.paging_state self._col_names = response.column_names self._col_types = response.column_types + new_result_metadata_id = getattr(response, 'result_metadata_id', None) + if self.prepared_statement and new_result_metadata_id is not None: + if response.column_metadata: + self.prepared_statement.result_metadata = response.column_metadata + self.prepared_statement.result_metadata_id = new_result_metadata_id if getattr(self.message, 'continuous_paging_options', None): self._handle_continuous_paging_first_response(connection, response) else: diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..a0d08b82e3 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -573,6 +573,9 @@ def _write_query_params(self, f, protocol_version): if self.timestamp is not None: flags |= _PROTOCOL_TIMESTAMP_FLAG + if self.skip_meta: + flags |= _SKIP_METADATA_FLAG + if self.keyspace is not None: if ProtocolVersion.uses_keyspace_flag(protocol_version): flags |= _WITH_KEYSPACE_FLAG @@ -642,6 +645,8 @@ def send_body(self, f, protocol_version): write_string(f, self.query_id) if ProtocolVersion.uses_prepared_metadata(protocol_version): write_string(f, self.result_metadata_id) + elif self.result_metadata_id is not None: + write_string(f, self.result_metadata_id) self._write_query_params(f, protocol_version) @@ -745,7 +750,7 @@ def decode_row(row): def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): self.query_id = read_binary_string(f) - if ProtocolVersion.uses_prepared_metadata(protocol_version): + if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id: self.result_metadata_id = read_binary_string(f) else: self.result_metadata_id = None diff --git a/cassandra/protocol_features.py b/cassandra/protocol_features.py index 877998be7d..5193d72a44 100644 --- a/cassandra/protocol_features.py +++ b/cassandra/protocol_features.py @@ -10,6 +10,7 @@ LWT_OPTIMIZATION_META_BIT_MASK = "LWT_OPTIMIZATION_META_BIT_MASK" RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR" TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1" +USE_METADATA_ID = "SCYLLA_USE_METADATA_ID" class ProtocolFeatures(object): rate_limit_error = None @@ -17,13 +18,16 @@ class ProtocolFeatures(object): sharding_info = None tablets_routing_v1 = False lwt_info = None + use_metadata_id = False - def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None): + def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None, + use_metadata_id=False): self.rate_limit_error = rate_limit_error self.shard_id = shard_id self.sharding_info = sharding_info self.tablets_routing_v1 = tablets_routing_v1 self.lwt_info = lwt_info + self.use_metadata_id = use_metadata_id @staticmethod def parse_from_supported(supported): @@ -31,7 +35,9 @@ def parse_from_supported(supported): shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported) tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported) lwt_info = ProtocolFeatures.parse_lwt_info(supported) - return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info) + use_metadata_id = ProtocolFeatures.parse_use_metadata_id(supported) + return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info, + use_metadata_id) @staticmethod def maybe_parse_rate_limit_error(supported): @@ -57,6 +63,8 @@ def add_startup_options(self, options): options[TABLETS_ROUTING_V1] = "" if self.lwt_info is not None: options[LWT_ADD_METADATA_MARK] = str(self.lwt_info.lwt_meta_bit_mask) + if self.use_metadata_id: + options[USE_METADATA_ID] = "" @staticmethod def parse_sharding_info(options): @@ -81,6 +89,10 @@ def parse_sharding_info(options): def parse_tablets_info(options): return TABLETS_ROUTING_V1 in options + @staticmethod + def parse_use_metadata_id(options): + return USE_METADATA_ID in options + @staticmethod def parse_lwt_info(options): value_list = options.get(LWT_ADD_METADATA_MARK, [None]) diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..17b474604b 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io +import struct import unittest from unittest.mock import Mock @@ -21,8 +23,10 @@ PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + _SKIP_METADATA_FLAG, + BatchMessage, ResultMessage ) +from cassandra.protocol_features import ProtocolFeatures from cassandra.query import BatchType from cassandra.marshal import uint32_unpack from cassandra.cluster import ContinuousPagingOptions @@ -68,6 +72,87 @@ def test_execute_message(self): (b'\x00\x04',), (b'\x00\x00\x00\x01',), (b'\x00\x00',)]) + def test_execute_message_skip_meta_flag(self): + """skip_meta=True must set _SKIP_METADATA_FLAG (0x02) in the flags byte.""" + message = ExecuteMessage('1', [], 4, skip_meta=True) + mock_io = Mock() + + message.send_body(mock_io, 4) + # flags byte should be VALUES_FLAG | SKIP_METADATA_FLAG = 0x01 | 0x02 = 0x03 + self._check_calls(mock_io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x03',), (b'\x00\x00',)]) + + def test_execute_message_scylla_metadata_id_v4(self): + """result_metadata_id should be written on protocol v4 when set (Scylla extension).""" + message = ExecuteMessage('1', [], 4) + message.result_metadata_id = b'foo' + mock_io = Mock() + + message.send_body(mock_io, 4) + # metadata_id written before query params (same position as v5) + self._check_calls(mock_io, [(b'\x00\x01',), (b'1',), + (b'\x00\x03',), (b'foo',), + (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) + + def test_recv_results_prepared_scylla_extension_reads_metadata_id(self): + """ + When use_metadata_id is True (Scylla extension), result_metadata_id must be + read from the PREPARE response even for protocol v4. + """ + # Build a minimal valid PREPARE response binary (no bind/result columns): + # query_id: short(2) + b'ab' + # result_metadata_id: short(3) + b'xyz' <-- only present when extension active + # prepared flags: int(1) = global_tables_spec + # colcount: int(0) + # num_pk_indexes: int(0) + # ksname: short(2) + b'ks' + # cfname: short(2) + b'tb' + # result flags: int(4) = no_metadata + # result colcount: int(0) + buf = io.BytesIO( + struct.pack('>H', 2) + b'ab' # query_id + + struct.pack('>H', 3) + b'xyz' # result_metadata_id + + struct.pack('>i', 1) # prepared flags: global_tables_spec + + struct.pack('>i', 0) # colcount = 0 + + struct.pack('>i', 0) # num_pk_indexes = 0 + + struct.pack('>H', 2) + b'ks' # ksname + + struct.pack('>H', 2) + b'tb' # cfname + + struct.pack('>i', 4) # result flags: no_metadata + + struct.pack('>i', 0) # result colcount = 0 + ) + + features_with_extension = ProtocolFeatures(use_metadata_id=True) + msg = ResultMessage(kind=4) # RESULT_KIND_PREPARED = 4 + msg.recv_results_prepared(buf, protocol_version=4, + protocol_features=features_with_extension, + user_type_map={}) + assert msg.query_id == b'ab' + assert msg.result_metadata_id == b'xyz' + + def test_recv_results_prepared_no_extension_skips_metadata_id(self): + """ + Without use_metadata_id, result_metadata_id must NOT be read on protocol v4. + The buffer must NOT contain a metadata_id field. + """ + buf = io.BytesIO( + struct.pack('>H', 2) + b'ab' # query_id + # no result_metadata_id + + struct.pack('>i', 1) # prepared flags: global_tables_spec + + struct.pack('>i', 0) # colcount = 0 + + struct.pack('>i', 0) # num_pk_indexes = 0 + + struct.pack('>H', 2) + b'ks' # ksname + + struct.pack('>H', 2) + b'tb' # cfname + + struct.pack('>i', 4) # result flags: no_metadata + + struct.pack('>i', 0) # result colcount = 0 + ) + + features_without_extension = ProtocolFeatures(use_metadata_id=False) + msg = ResultMessage(kind=4) + msg.recv_results_prepared(buf, protocol_version=4, + protocol_features=features_without_extension, + user_type_map={}) + assert msg.query_id == b'ab' + assert msg.result_metadata_id is None + def test_query_message(self): """ Test to check the appropriate calls are made diff --git a/tests/unit/test_protocol_features.py b/tests/unit/test_protocol_features.py index 895c384f7e..387583680b 100644 --- a/tests/unit/test_protocol_features.py +++ b/tests/unit/test_protocol_features.py @@ -22,3 +22,38 @@ class OptionsHolder(object): assert protocol_features.rate_limit_error == 123 assert protocol_features.shard_id == 0 assert protocol_features.sharding_info is None + + def test_use_metadata_id_parsing(self): + """ + Test that SCYLLA_USE_METADATA_ID is parsed from SUPPORTED options. + """ + options = {'SCYLLA_USE_METADATA_ID': ['']} + protocol_features = ProtocolFeatures.parse_from_supported(options) + assert protocol_features.use_metadata_id is True + + def test_use_metadata_id_missing(self): + """ + Test that use_metadata_id is False when SCYLLA_USE_METADATA_ID is absent. + """ + options = {'SCYLLA_RATE_LIMIT_ERROR': ['ERROR_CODE=1']} + protocol_features = ProtocolFeatures.parse_from_supported(options) + assert protocol_features.use_metadata_id is False + + def test_use_metadata_id_startup_options(self): + """ + Test that SCYLLA_USE_METADATA_ID is included in STARTUP options when negotiated. + """ + options = {'SCYLLA_USE_METADATA_ID': ['']} + protocol_features = ProtocolFeatures.parse_from_supported(options) + startup = {} + protocol_features.add_startup_options(startup) + assert 'SCYLLA_USE_METADATA_ID' in startup + + def test_use_metadata_id_not_in_startup_when_not_negotiated(self): + """ + Test that SCYLLA_USE_METADATA_ID is NOT included in STARTUP when not negotiated. + """ + protocol_features = ProtocolFeatures.parse_from_supported({}) + startup = {} + protocol_features.add_startup_options(startup) + assert 'SCYLLA_USE_METADATA_ID' not in startup