Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2998,9 +2998,7 @@ def _create_response_future(self, query, parameters, trace, custom_payload,
message = ExecuteMessage(
prepared_statement.query_id, query.values, cl,
serial_cl, fetch_size, paging_state, timestamp,
skip_meta=bool(prepared_statement.result_metadata),
continuous_paging_options=continuous_paging_options,
result_metadata_id=prepared_statement.result_metadata_id)
continuous_paging_options=continuous_paging_options)
elif isinstance(query, BatchStatement):
if self._protocol_version < 2:
raise UnsupportedOperation(
Expand Down Expand Up @@ -4618,6 +4616,15 @@ def _query(self, host, message=None, cb=None):
self._connection = connection
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []

if self.prepared_statement and isinstance(message, ExecuteMessage):
has_result_metadata_id = self.prepared_statement.result_metadata_id is not None
use_metadata_id = has_result_metadata_id and (
ProtocolVersion.uses_prepared_metadata(connection.protocol_version)
or connection.features.use_metadata_id
)
message.skip_meta = use_metadata_id
message.result_metadata_id = self.prepared_statement.result_metadata_id if use_metadata_id else None

if cb is None:
cb = partial(self._set_result, host, connection, pool)

Expand Down Expand Up @@ -4774,6 +4781,11 @@ def _set_result(self, host, connection, pool, response):
self._paging_state = response.paging_state
self._col_names = response.column_names
self._col_types = response.column_types
new_result_metadata_id = getattr(response, 'result_metadata_id', None)
if self.prepared_statement and new_result_metadata_id is not None:
if response.column_metadata:
self.prepared_statement.result_metadata = response.column_metadata
self.prepared_statement.result_metadata_id = new_result_metadata_id
if getattr(self.message, 'continuous_paging_options', None):
self._handle_continuous_paging_first_response(connection, response)
else:
Expand Down
7 changes: 6 additions & 1 deletion cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,9 @@ def _write_query_params(self, f, protocol_version):
if self.timestamp is not None:
flags |= _PROTOCOL_TIMESTAMP_FLAG

if self.skip_meta:
flags |= _SKIP_METADATA_FLAG

if self.keyspace is not None:
if ProtocolVersion.uses_keyspace_flag(protocol_version):
flags |= _WITH_KEYSPACE_FLAG
Expand Down Expand Up @@ -642,6 +645,8 @@ def send_body(self, f, protocol_version):
write_string(f, self.query_id)
if ProtocolVersion.uses_prepared_metadata(protocol_version):
write_string(f, self.result_metadata_id)
elif self.result_metadata_id is not None:
write_string(f, self.result_metadata_id)
self._write_query_params(f, protocol_version)


Expand Down Expand Up @@ -745,7 +750,7 @@ def decode_row(row):

def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
self.query_id = read_binary_string(f)
if ProtocolVersion.uses_prepared_metadata(protocol_version):
if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id:
self.result_metadata_id = read_binary_string(f)
else:
self.result_metadata_id = None
Expand Down
16 changes: 14 additions & 2 deletions cassandra/protocol_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,34 @@
LWT_OPTIMIZATION_META_BIT_MASK = "LWT_OPTIMIZATION_META_BIT_MASK"
RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"
USE_METADATA_ID = "SCYLLA_USE_METADATA_ID"

class ProtocolFeatures(object):
rate_limit_error = None
shard_id = 0
sharding_info = None
tablets_routing_v1 = False
lwt_info = None
use_metadata_id = False

def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None):
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None,
use_metadata_id=False):
self.rate_limit_error = rate_limit_error
self.shard_id = shard_id
self.sharding_info = sharding_info
self.tablets_routing_v1 = tablets_routing_v1
self.lwt_info = lwt_info
self.use_metadata_id = use_metadata_id

@staticmethod
def parse_from_supported(supported):
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
lwt_info = ProtocolFeatures.parse_lwt_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info)
use_metadata_id = ProtocolFeatures.parse_use_metadata_id(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info,
use_metadata_id)

@staticmethod
def maybe_parse_rate_limit_error(supported):
Expand All @@ -57,6 +63,8 @@ def add_startup_options(self, options):
options[TABLETS_ROUTING_V1] = ""
if self.lwt_info is not None:
options[LWT_ADD_METADATA_MARK] = str(self.lwt_info.lwt_meta_bit_mask)
if self.use_metadata_id:
options[USE_METADATA_ID] = ""

@staticmethod
def parse_sharding_info(options):
Expand All @@ -81,6 +89,10 @@ def parse_sharding_info(options):
def parse_tablets_info(options):
return TABLETS_ROUTING_V1 in options

@staticmethod
def parse_use_metadata_id(options):
return USE_METADATA_ID in options

@staticmethod
def parse_lwt_info(options):
value_list = options.get(LWT_ADD_METADATA_MARK, [None])
Expand Down
87 changes: 86 additions & 1 deletion tests/unit/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import struct
import unittest

from unittest.mock import Mock
Expand All @@ -21,8 +23,10 @@
PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation,
_PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG,
_PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
BatchMessage
_SKIP_METADATA_FLAG,
BatchMessage, ResultMessage
)
from cassandra.protocol_features import ProtocolFeatures
from cassandra.query import BatchType
from cassandra.marshal import uint32_unpack
from cassandra.cluster import ContinuousPagingOptions
Expand Down Expand Up @@ -68,6 +72,87 @@ def test_execute_message(self):
(b'\x00\x04',),
(b'\x00\x00\x00\x01',), (b'\x00\x00',)])

def test_execute_message_skip_meta_flag(self):
"""skip_meta=True must set _SKIP_METADATA_FLAG (0x02) in the flags byte."""
message = ExecuteMessage('1', [], 4, skip_meta=True)
mock_io = Mock()

message.send_body(mock_io, 4)
# flags byte should be VALUES_FLAG | SKIP_METADATA_FLAG = 0x01 | 0x02 = 0x03
self._check_calls(mock_io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x03',), (b'\x00\x00',)])

def test_execute_message_scylla_metadata_id_v4(self):
"""result_metadata_id should be written on protocol v4 when set (Scylla extension)."""
message = ExecuteMessage('1', [], 4)
message.result_metadata_id = b'foo'
mock_io = Mock()

message.send_body(mock_io, 4)
# metadata_id written before query params (same position as v5)
self._check_calls(mock_io, [(b'\x00\x01',), (b'1',),
(b'\x00\x03',), (b'foo',),
(b'\x00\x04',), (b'\x01',), (b'\x00\x00',)])

def test_recv_results_prepared_scylla_extension_reads_metadata_id(self):
"""
When use_metadata_id is True (Scylla extension), result_metadata_id must be
read from the PREPARE response even for protocol v4.
"""
# Build a minimal valid PREPARE response binary (no bind/result columns):
# query_id: short(2) + b'ab'
# result_metadata_id: short(3) + b'xyz' <-- only present when extension active
# prepared flags: int(1) = global_tables_spec
# colcount: int(0)
# num_pk_indexes: int(0)
# ksname: short(2) + b'ks'
# cfname: short(2) + b'tb'
# result flags: int(4) = no_metadata
# result colcount: int(0)
buf = io.BytesIO(
struct.pack('>H', 2) + b'ab' # query_id
+ struct.pack('>H', 3) + b'xyz' # result_metadata_id
+ struct.pack('>i', 1) # prepared flags: global_tables_spec
+ struct.pack('>i', 0) # colcount = 0
+ struct.pack('>i', 0) # num_pk_indexes = 0
+ struct.pack('>H', 2) + b'ks' # ksname
+ struct.pack('>H', 2) + b'tb' # cfname
+ struct.pack('>i', 4) # result flags: no_metadata
+ struct.pack('>i', 0) # result colcount = 0
)

features_with_extension = ProtocolFeatures(use_metadata_id=True)
msg = ResultMessage(kind=4) # RESULT_KIND_PREPARED = 4
msg.recv_results_prepared(buf, protocol_version=4,
protocol_features=features_with_extension,
user_type_map={})
assert msg.query_id == b'ab'
assert msg.result_metadata_id == b'xyz'

def test_recv_results_prepared_no_extension_skips_metadata_id(self):
"""
Without use_metadata_id, result_metadata_id must NOT be read on protocol v4.
The buffer must NOT contain a metadata_id field.
"""
buf = io.BytesIO(
struct.pack('>H', 2) + b'ab' # query_id
# no result_metadata_id
+ struct.pack('>i', 1) # prepared flags: global_tables_spec
+ struct.pack('>i', 0) # colcount = 0
+ struct.pack('>i', 0) # num_pk_indexes = 0
+ struct.pack('>H', 2) + b'ks' # ksname
+ struct.pack('>H', 2) + b'tb' # cfname
+ struct.pack('>i', 4) # result flags: no_metadata
+ struct.pack('>i', 0) # result colcount = 0
)

features_without_extension = ProtocolFeatures(use_metadata_id=False)
msg = ResultMessage(kind=4)
msg.recv_results_prepared(buf, protocol_version=4,
protocol_features=features_without_extension,
user_type_map={})
assert msg.query_id == b'ab'
assert msg.result_metadata_id is None

def test_query_message(self):
"""
Test to check the appropriate calls are made
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_protocol_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,38 @@ class OptionsHolder(object):
assert protocol_features.rate_limit_error == 123
assert protocol_features.shard_id == 0
assert protocol_features.sharding_info is None

def test_use_metadata_id_parsing(self):
"""
Test that SCYLLA_USE_METADATA_ID is parsed from SUPPORTED options.
"""
options = {'SCYLLA_USE_METADATA_ID': ['']}
protocol_features = ProtocolFeatures.parse_from_supported(options)
assert protocol_features.use_metadata_id is True

def test_use_metadata_id_missing(self):
"""
Test that use_metadata_id is False when SCYLLA_USE_METADATA_ID is absent.
"""
options = {'SCYLLA_RATE_LIMIT_ERROR': ['ERROR_CODE=1']}
protocol_features = ProtocolFeatures.parse_from_supported(options)
assert protocol_features.use_metadata_id is False

def test_use_metadata_id_startup_options(self):
"""
Test that SCYLLA_USE_METADATA_ID is included in STARTUP options when negotiated.
"""
options = {'SCYLLA_USE_METADATA_ID': ['']}
protocol_features = ProtocolFeatures.parse_from_supported(options)
startup = {}
protocol_features.add_startup_options(startup)
assert 'SCYLLA_USE_METADATA_ID' in startup

def test_use_metadata_id_not_in_startup_when_not_negotiated(self):
"""
Test that SCYLLA_USE_METADATA_ID is NOT included in STARTUP when not negotiated.
"""
protocol_features = ProtocolFeatures.parse_from_supported({})
startup = {}
protocol_features.add_startup_options(startup)
assert 'SCYLLA_USE_METADATA_ID' not in startup
Loading