diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..1d95d68130 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -717,31 +717,35 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, self.recv_results_metadata(f, user_type_map) column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) - rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] self.column_names = [c[2] for c in column_metadata] self.column_types = [c[3] for c in column_metadata] - col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - def decode_val(val, col_md, col_desc): - uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) - col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] - raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val - return col_type.from_binary(raw_bytes, protocol_version) - - def decode_row(row): - return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) - - try: - self.parsed_rows = [decode_row(row) for row in rows] - except Exception: - for row in rows: - for val, col_md, col_desc in zip(row, column_metadata, col_descs): - try: - decode_val(val, col_md, col_desc) - except Exception as e: - raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], - col_md[3].cql_parameterized_type(), - str(e))) + if not column_encryption_policy: + # Fast path: no column encryption — decode inline, skip ColDesc creation + self.parsed_rows = [ + _decode_row_inline(f, column_metadata, protocol_version) + for _ in range(rowcount) + ] + else: + # Slow path: column encryption enabled — need ColDesc and per-column CE check + rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] + try: + self.parsed_rows = [ + _decode_row_ce(row, column_metadata, col_descs, + column_encryption_policy, protocol_version) + for row in rows + ] + except Exception: + for row in rows: + for val, col_md, col_desc in zip(row, column_metadata, col_descs): + try: + _decode_val_ce(val, col_md, col_desc, + column_encryption_policy, protocol_version) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], + col_md[3].cql_parameterized_type(), + str(e))) def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): self.query_id = read_binary_string(f) @@ -1424,6 +1428,41 @@ def read_error_code_map(f): return error_code_map + +def _decode_row_inline(f, column_metadata, protocol_version): + """Decode a single row directly from the buffer (no column encryption).""" + row = [] + for col_md in column_metadata: + size = read_int(f) + if size < 0: + row.append(None) + else: + val = f.read(size) + try: + row.append(col_md[3].from_binary(val, protocol_version)) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], + col_md[3].cql_parameterized_type(), + str(e))) + return tuple(row) + + +def _decode_val_ce(val, col_md, col_desc, column_encryption_policy, protocol_version): + """Decode a single column value with column encryption support.""" + uses_ce = column_encryption_policy.contains_column(col_desc) + col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] + raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val + return col_type.from_binary(raw_bytes, protocol_version) + + +def _decode_row_ce(row, column_metadata, col_descs, column_encryption_policy, protocol_version): + """Decode a full row with column encryption support.""" + return tuple( + _decode_val_ce(val, col_md, col_desc, column_encryption_policy, protocol_version) + for val, col_md, col_desc in zip(row, column_metadata, col_descs) + ) + + def read_value(f): size = read_int(f) if size < 0: diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..9427cedb5a 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io import unittest from unittest.mock import Mock -from cassandra import ProtocolVersion, UnsupportedOperation +from cassandra import DriverException, ProtocolVersion, UnsupportedOperation, type_codes from cassandra.protocol import ( - PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, + PrepareMessage, QueryMessage, ExecuteMessage, ResultMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + BatchMessage, RESULT_KIND_ROWS, write_int, write_short, write_string ) from cassandra.query import BatchType from cassandra.marshal import uint32_unpack @@ -31,6 +32,22 @@ class MessageTest(unittest.TestCase): + def test_result_message_wraps_inline_decode_errors(self): + body = io.BytesIO() + write_int(body, RESULT_KIND_ROWS) + write_int(body, 0) + write_int(body, 1) + write_string(body, "ks") + write_string(body, "tbl") + write_string(body, "v") + write_short(body, type_codes.DateType) + write_int(body, 1) + write_int(body, 1) + body.write(b"\x00") + + with pytest.raises(DriverException, match='Failed decoding result column "v"'): + ResultMessage.recv_body(io.BytesIO(body.getvalue()), ProtocolVersion.V4, 0, {}, None, None) + def test_prepare_message(self): """ Test to check the appropriate calls are made