diff --git a/pyiceberg/table/encryption.py b/pyiceberg/table/encryption.py new file mode 100644 index 0000000000..5fece598ca --- /dev/null +++ b/pyiceberg/table/encryption.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Encryption metadata plumbing only. + +This module models encrypted table metadata for faithful JSON round-tripping. +It does not implement cryptography, KMS integration, or key wrapping. +""" + +from __future__ import annotations + +from pydantic import Field + +from pyiceberg.typedef import IcebergBaseModel + + +class EncryptedKey(IcebergBaseModel): + key_id: str = Field(alias="key-id") + encrypted_key_metadata: str = Field(alias="encrypted-key-metadata") + encrypted_by_id: str | None = Field(alias="encrypted-by-id", default=None) + properties: dict[str, str] | None = Field(default=None) diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 26b6e3d3ad..16ae527032 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -28,6 +28,7 @@ from pyiceberg.exceptions import ValidationError from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids from pyiceberg.schema import Schema, assign_fresh_schema_ids +from pyiceberg.table.encryption import EncryptedKey from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry @@ -574,6 +575,9 @@ def construct_refs(self) -> TableMetadata: next_row_id: int | None = Field(alias="next-row-id", default=None) """A long higher than all assigned row IDs; the next snapshot's `first-row-id`.""" + encryption_keys: list[EncryptedKey] | None = Field(alias="encryption-keys", default=None) + """A list of encrypted keys used by this table.""" + def model_dump_json(self, exclude_none: bool = True, exclude: Any | None = None, by_alias: bool = True, **kwargs: Any) -> str: raise NotImplementedError("Writing V3 is not yet supported, see: https://github.com/apache/iceberg-python/issues/1551") diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 7e4c6eb1ec..ef24a8dd44 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -252,6 +252,9 @@ class Snapshot(IcebergBaseModel): added_rows: int | None = Field( alias="added-rows", default=None, description="The upper bound of the number of rows with assigned row IDs" ) + key_id: str | None = Field( + alias="key-id", default=None, description="ID of the encryption key that encrypts the manifest list key metadata" + ) def __str__(self) -> str: """Return the string representation of the Snapshot class.""" @@ -273,6 +276,7 @@ def __repr__(self) -> str: f"schema_id={self.schema_id}" if self.schema_id is not None else None, f"first_row_id={self.first_row_id}" if self.first_row_id is not None else None, f"added_rows={self.added_rows}" if self.added_rows is not None else None, + f"key_id='{self.key_id}'" if self.key_id is not None else None, ] filtered_fields = [field for field in fields if field is not None] return f"Snapshot({', '.join(filtered_fields)})" diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 3c98215366..451aef595f 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -221,6 +221,8 @@ def handle_primitive_type(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> return BinaryType() if v == "unknown": return UnknownType() + if v == "variant": + return VariantType() if v.startswith("fixed"): return FixedType(_parse_fixed_type(v)) if v.startswith("decimal"): @@ -954,6 +956,25 @@ def minimum_format_version(self) -> TableVersion: return 3 +class VariantType(PrimitiveType): + """A variant data type in Iceberg can be represented using an instance of this class. + + Variants in Iceberg are semi-structured values encoded using the Parquet Variant binary format. + + Example: + >>> column_foo = VariantType() + >>> isinstance(column_foo, VariantType) + True + >>> column_foo + VariantType() + """ + + root: Literal["variant"] = Field(default="variant") + + def minimum_format_version(self) -> TableVersion: + return 3 + + class GeometryType(PrimitiveType): """A geometry data type in Iceberg (v3+) for storing spatial geometries. diff --git a/pyiceberg/utils/variant.py b/pyiceberg/utils/variant.py new file mode 100644 index 0000000000..213fc6ede6 --- /dev/null +++ b/pyiceberg/utils/variant.py @@ -0,0 +1,511 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pure-python encoder and decoder for the Parquet Variant binary format. + +This implements only the non-shredded (metadata + value) Variant encoding. Shredded +Variant and Arrow-native interop are not supported here; reading/writing Variant columns +through Arrow is blocked on Arrow issues #45937, #50131, and #50132. ``VariantType`` is +intentionally not registered with the Arrow schema visitors, so converting a schema that +contains a Variant to Arrow raises rather than silently producing wrong data. +""" + +from __future__ import annotations + +import struct +from datetime import date, datetime, timedelta, timezone +from decimal import Decimal +from typing import Any + +_METADATA_VERSION = 1 + +_EPOCH_DATE = date(1970, 1, 1) +_EPOCH_DATETIME_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) + +_BASIC_TYPE_PRIMITIVE = 0 +_BASIC_TYPE_SHORT_STRING = 1 +_BASIC_TYPE_OBJECT = 2 +_BASIC_TYPE_ARRAY = 3 + +_PRIMITIVE_NULL = 0 +_PRIMITIVE_TRUE = 1 +_PRIMITIVE_FALSE = 2 +_PRIMITIVE_INT8 = 3 +_PRIMITIVE_INT16 = 4 +_PRIMITIVE_INT32 = 5 +_PRIMITIVE_INT64 = 6 +_PRIMITIVE_DOUBLE = 7 +_PRIMITIVE_DECIMAL4 = 8 +_PRIMITIVE_DECIMAL8 = 9 +_PRIMITIVE_DECIMAL16 = 10 +_PRIMITIVE_DATE = 11 +_PRIMITIVE_TIMESTAMP_TZ = 12 +_PRIMITIVE_TIMESTAMP_NTZ = 13 +_PRIMITIVE_FLOAT = 14 +_PRIMITIVE_BINARY = 15 +_PRIMITIVE_STRING = 16 + +_MAX_DECIMAL_SCALE = 38 +_MAX_INT128 = (1 << 127) - 1 +_MIN_INT128 = -(1 << 127) + +_MAX_UINT32 = (1 << 32) - 1 +_MIN_INT8 = -(1 << 7) +_MAX_INT8 = (1 << 7) - 1 +_MIN_INT16 = -(1 << 15) +_MAX_INT16 = (1 << 15) - 1 +_MIN_INT32 = -(1 << 31) +_MAX_INT32 = (1 << 31) - 1 +_MIN_INT64 = -(1 << 63) +_MAX_INT64 = (1 << 63) - 1 + + +def encode_variant(py_value: Any) -> tuple[bytes, bytes]: + """Encode a Python value as an Iceberg/Parquet Variant binary pair. + + Args: + py_value: A variant value represented with native Python types. + + Returns: + A pair of metadata bytes and value bytes. + """ + field_names: set[str] = set() + _collect_field_names(py_value, field_names) + dictionary_strings = sorted(field_names) + dictionary = {name: index for index, name in enumerate(dictionary_strings)} + + return _encode_metadata(dictionary_strings), _encode_value(py_value, dictionary) + + +def decode_variant(metadata_bytes: bytes, value_bytes: bytes) -> Any: + """Decode an Iceberg/Parquet Variant binary pair into native Python values.""" + dictionary = _decode_metadata(metadata_bytes) + value, offset = _decode_value(value_bytes, 0, dictionary) + if offset != len(value_bytes): + raise ValueError("Variant value contains trailing bytes") + return value + + +def _collect_field_names(value: Any, field_names: set[str]) -> None: + if isinstance(value, dict): + for key, child in value.items(): + if not isinstance(key, str): + raise ValueError(f"Variant object field names must be strings: {key!r}") + field_names.add(key) + _collect_field_names(child, field_names) + elif isinstance(value, list): + for child in value: + _collect_field_names(child, field_names) + elif value is None or isinstance(value, (bool, int, float, str, bytes, Decimal, date, datetime)): + return + else: + raise ValueError(f"Unsupported variant value: {value!r}") + + +def _encode_metadata(dictionary_strings: list[str]) -> bytes: + encoded_strings = [value.encode("utf-8") for value in dictionary_strings] + dictionary_bytes = b"".join(encoded_strings) + offset_size = _unsigned_width(max(len(dictionary_strings), len(dictionary_bytes))) + header = _METADATA_VERSION | 0x10 | ((offset_size - 1) << 6) + + offsets = [0] + offset = 0 + for value in encoded_strings: + offset += len(value) + offsets.append(offset) + + return b"".join( + [ + bytes([header]), + _write_unsigned(len(dictionary_strings), offset_size), + *(_write_unsigned(offset, offset_size) for offset in offsets), + dictionary_bytes, + ] + ) + + +def _decode_metadata(metadata_bytes: bytes) -> list[str]: + if not metadata_bytes: + raise ValueError("Variant metadata is empty") + + header = metadata_bytes[0] + if header & 0x20: + raise ValueError("Variant metadata reserved bit is set") + version = header & 0x0F + if version != _METADATA_VERSION: + raise ValueError(f"Unsupported variant metadata version: {version}") + + offset_size = ((header >> 6) & 0x03) + 1 + offset = 1 + dictionary_size, offset = _read_unsigned(metadata_bytes, offset, offset_size) + + offsets = [] + for _ in range(dictionary_size + 1): + value, offset = _read_unsigned(metadata_bytes, offset, offset_size) + offsets.append(value) + + if not offsets or offsets[0] != 0: + raise ValueError("Variant metadata dictionary offsets must start with zero") + if any(left > right for left, right in zip(offsets, offsets[1:], strict=False)): + raise ValueError("Variant metadata dictionary offsets must be ordered") + + dictionary_bytes = metadata_bytes[offset:] + if offsets[-1] != len(dictionary_bytes): + raise ValueError("Variant metadata dictionary length does not match offsets") + + return [dictionary_bytes[start:end].decode("utf-8") for start, end in zip(offsets, offsets[1:], strict=False)] + + +def _encode_value(value: Any, dictionary: dict[str, int]) -> bytes: + if value is None: + return _primitive_header(_PRIMITIVE_NULL) + if value is True: + return _primitive_header(_PRIMITIVE_TRUE) + if value is False: + return _primitive_header(_PRIMITIVE_FALSE) + if isinstance(value, int) and not isinstance(value, bool): + return _encode_int(value) + if isinstance(value, float): + return _primitive_header(_PRIMITIVE_DOUBLE) + struct.pack(" bytes: + return bytes([(primitive_type << 2) | _BASIC_TYPE_PRIMITIVE]) + + +def _encode_int(value: int) -> bytes: + if _MIN_INT8 <= value <= _MAX_INT8: + return _primitive_header(_PRIMITIVE_INT8) + value.to_bytes(1, "little", signed=True) + if _MIN_INT16 <= value <= _MAX_INT16: + return _primitive_header(_PRIMITIVE_INT16) + value.to_bytes(2, "little", signed=True) + if _MIN_INT32 <= value <= _MAX_INT32: + return _primitive_header(_PRIMITIVE_INT32) + value.to_bytes(4, "little", signed=True) + if _MIN_INT64 <= value <= _MAX_INT64: + return _primitive_header(_PRIMITIVE_INT64) + value.to_bytes(8, "little", signed=True) + raise ValueError(f"Variant integer out of int64 range: {value}") + + +def _encode_string(value: str) -> bytes: + value_bytes = value.encode("utf-8") + if len(value_bytes) < 64: + return bytes([(len(value_bytes) << 2) | _BASIC_TYPE_SHORT_STRING]) + value_bytes + if len(value_bytes) > _MAX_UINT32: + raise ValueError("Variant string is too long") + return _primitive_header(_PRIMITIVE_STRING) + _write_unsigned(len(value_bytes), 4) + value_bytes + + +def _encode_binary(value: bytes) -> bytes: + if len(value) > _MAX_UINT32: + raise ValueError("Variant binary value is too long") + return _primitive_header(_PRIMITIVE_BINARY) + _write_unsigned(len(value), 4) + value + + +def _encode_decimal(value: Decimal) -> bytes: + sign, digits, exponent = value.as_tuple() + if not isinstance(exponent, int): + raise ValueError(f"Variant decimal cannot encode non-finite value: {value}") + scale = -exponent + if scale < 0 or scale > _MAX_DECIMAL_SCALE: + raise ValueError(f"Variant decimal scale must be in [0, 38]: {scale}") + unscaled = int("".join(str(digit) for digit in digits) or "0") + if sign: + unscaled = -unscaled + + if _MIN_INT32 <= unscaled <= _MAX_INT32: + primitive_type, width = _PRIMITIVE_DECIMAL4, 4 + elif _MIN_INT64 <= unscaled <= _MAX_INT64: + primitive_type, width = _PRIMITIVE_DECIMAL8, 8 + elif _MIN_INT128 <= unscaled <= _MAX_INT128: + primitive_type, width = _PRIMITIVE_DECIMAL16, 16 + else: + raise ValueError(f"Variant decimal unscaled value out of int128 range: {unscaled}") + + return _primitive_header(primitive_type) + bytes([scale]) + unscaled.to_bytes(width, "little", signed=True) + + +def _encode_timestamp(value: datetime) -> bytes: + if value.tzinfo is not None: + micros = round((value - _EPOCH_DATETIME_UTC).total_seconds() * 1_000_000) + primitive_type = _PRIMITIVE_TIMESTAMP_TZ + else: + micros = round((value - _EPOCH_DATETIME_UTC.replace(tzinfo=None)).total_seconds() * 1_000_000) + primitive_type = _PRIMITIVE_TIMESTAMP_NTZ + return _primitive_header(primitive_type) + micros.to_bytes(8, "little", signed=True) + + +def _encode_object(value: dict[Any, Any], dictionary: dict[str, int]) -> bytes: + items = [] + for key, child in value.items(): + if not isinstance(key, str): + raise ValueError(f"Variant object field names must be strings: {key!r}") + items.append((key, child)) + items.sort(key=lambda item: item[0]) + + encoded_values = [_encode_value(child, dictionary) for _, child in items] + value_region = b"".join(encoded_values) + field_offsets = _offsets(encoded_values) + field_ids = [dictionary[key] for key, _ in items] + + field_offset_size = _unsigned_width(len(value_region)) + field_id_size = _unsigned_width(max(field_ids, default=0)) + is_large = len(items) > 255 + header = _BASIC_TYPE_OBJECT | ((field_offset_size - 1) << 2) | ((field_id_size - 1) << 4) | (0x40 if is_large else 0) + + return b"".join( + [ + bytes([header]), + _write_unsigned(len(items), 4 if is_large else 1), + *(_write_unsigned(field_id, field_id_size) for field_id in field_ids), + *(_write_unsigned(offset, field_offset_size) for offset in field_offsets), + value_region, + ] + ) + + +def _encode_array(value: list[Any], dictionary: dict[str, int]) -> bytes: + encoded_values = [_encode_value(child, dictionary) for child in value] + value_region = b"".join(encoded_values) + field_offsets = _offsets(encoded_values) + + field_offset_size = _unsigned_width(len(value_region)) + is_large = len(value) > 255 + header = _BASIC_TYPE_ARRAY | ((field_offset_size - 1) << 2) | (0x10 if is_large else 0) + + return b"".join( + [ + bytes([header]), + _write_unsigned(len(value), 4 if is_large else 1), + *(_write_unsigned(offset, field_offset_size) for offset in field_offsets), + value_region, + ] + ) + + +def _offsets(encoded_values: list[bytes]) -> list[int]: + offsets = [0] + offset = 0 + for value in encoded_values: + offset += len(value) + offsets.append(offset) + return offsets + + +def _decode_value(value_bytes: bytes, offset: int, dictionary: list[str]) -> tuple[Any, int]: + if offset >= len(value_bytes): + raise ValueError("Unexpected end of variant value") + + metadata = value_bytes[offset] + offset += 1 + basic_type = metadata & 0x03 + value_header = metadata >> 2 + + if basic_type == _BASIC_TYPE_PRIMITIVE: + return _decode_primitive(value_header, value_bytes, offset) + if basic_type == _BASIC_TYPE_SHORT_STRING: + return _read_utf8(value_bytes, offset, value_header) + if basic_type == _BASIC_TYPE_OBJECT: + return _decode_object(metadata, value_bytes, offset, dictionary) + if basic_type == _BASIC_TYPE_ARRAY: + return _decode_array(metadata, value_bytes, offset, dictionary) + + raise ValueError(f"Unsupported variant basic type: {basic_type}") + + +def _decode_primitive(primitive_type: int, value_bytes: bytes, offset: int) -> tuple[Any, int]: + if primitive_type == _PRIMITIVE_NULL: + return None, offset + if primitive_type == _PRIMITIVE_TRUE: + return True, offset + if primitive_type == _PRIMITIVE_FALSE: + return False, offset + if primitive_type == _PRIMITIVE_INT8: + return _read_signed(value_bytes, offset, 1) + if primitive_type == _PRIMITIVE_INT16: + return _read_signed(value_bytes, offset, 2) + if primitive_type == _PRIMITIVE_INT32: + return _read_signed(value_bytes, offset, 4) + if primitive_type == _PRIMITIVE_INT64: + return _read_signed(value_bytes, offset, 8) + if primitive_type == _PRIMITIVE_DOUBLE: + _require_available(value_bytes, offset, 8) + return struct.unpack_from(" tuple[Decimal, int]: + _require_available(value_bytes, offset, 1) + scale = value_bytes[offset] + offset += 1 + if scale > _MAX_DECIMAL_SCALE: + raise ValueError(f"Variant decimal scale must be in [0, 38]: {scale}") + unscaled, offset = _read_signed(value_bytes, offset, width) + sign = 1 if unscaled < 0 else 0 + digits = tuple(int(digit) for digit in str(abs(unscaled))) + return Decimal((sign, digits, -scale)), offset + + +def _decode_object(metadata: int, value_bytes: bytes, offset: int, dictionary: list[str]) -> tuple[dict[str, Any], int]: + if metadata & 0x80: + raise ValueError("Variant object reserved bit is set") + + field_offset_size = ((metadata >> 2) & 0x03) + 1 + field_id_size = ((metadata >> 4) & 0x03) + 1 + is_large = bool(metadata & 0x40) + + num_elements, offset = _read_unsigned(value_bytes, offset, 4 if is_large else 1) + + field_ids = [] + for _ in range(num_elements): + field_id, offset = _read_unsigned(value_bytes, offset, field_id_size) + if field_id >= len(dictionary): + raise ValueError(f"Variant object field id out of range: {field_id}") + field_ids.append(field_id) + + field_offsets, offset = _read_offsets(value_bytes, offset, field_offset_size, num_elements) + value_region_start = offset + value_region_end = value_region_start + field_offsets[-1] + _require_available(value_bytes, value_region_start, field_offsets[-1]) + + result = {} + for index, field_id in enumerate(field_ids): + start = value_region_start + field_offsets[index] + expected_end = value_region_start + field_offsets[index + 1] + field_value, actual_end = _decode_value(value_bytes[:value_region_end], start, dictionary) + if actual_end != expected_end: + raise ValueError("Variant object field value length does not match offset") + result[dictionary[field_id]] = field_value + + return result, value_region_end + + +def _decode_array(metadata: int, value_bytes: bytes, offset: int, dictionary: list[str]) -> tuple[list[Any], int]: + if metadata & 0xE0: + raise ValueError("Variant array reserved bits are set") + + field_offset_size = ((metadata >> 2) & 0x03) + 1 + is_large = bool(metadata & 0x10) + + num_elements, offset = _read_unsigned(value_bytes, offset, 4 if is_large else 1) + field_offsets, offset = _read_offsets(value_bytes, offset, field_offset_size, num_elements) + + value_region_start = offset + value_region_end = value_region_start + field_offsets[-1] + _require_available(value_bytes, value_region_start, field_offsets[-1]) + + result = [] + for index in range(num_elements): + start = value_region_start + field_offsets[index] + expected_end = value_region_start + field_offsets[index + 1] + item, actual_end = _decode_value(value_bytes[:value_region_end], start, dictionary) + if actual_end != expected_end: + raise ValueError("Variant array element value length does not match offset") + result.append(item) + + return result, value_region_end + + +def _read_offsets(value_bytes: bytes, offset: int, offset_size: int, num_elements: int) -> tuple[list[int], int]: + offsets = [] + for _ in range(num_elements + 1): + value, offset = _read_unsigned(value_bytes, offset, offset_size) + offsets.append(value) + + if not offsets or offsets[0] != 0: + raise ValueError("Variant offsets must start with zero") + if any(left > right for left, right in zip(offsets, offsets[1:], strict=False)): + raise ValueError("Variant offsets must be ordered") + return offsets, offset + + +def _read_utf8(value_bytes: bytes, offset: int, length: int) -> tuple[str, int]: + _require_available(value_bytes, offset, length) + return value_bytes[offset : offset + length].decode("utf-8"), offset + length + + +def _read_signed(value_bytes: bytes, offset: int, width: int) -> tuple[int, int]: + _require_available(value_bytes, offset, width) + return int.from_bytes(value_bytes[offset : offset + width], "little", signed=True), offset + width + + +def _read_unsigned(value_bytes: bytes, offset: int, width: int) -> tuple[int, int]: + _require_available(value_bytes, offset, width) + return int.from_bytes(value_bytes[offset : offset + width], "little", signed=False), offset + width + + +def _write_unsigned(value: int, width: int) -> bytes: + if value < 0 or value > (1 << (width * 8)) - 1: + raise ValueError(f"Value {value} does not fit in {width} bytes") + return value.to_bytes(width, "little", signed=False) + + +def _require_available(value_bytes: bytes, offset: int, length: int) -> None: + if offset < 0 or length < 0 or offset + length > len(value_bytes): + raise ValueError("Unexpected end of variant value") + + +def _unsigned_width(value: int) -> int: + if value < 0: + raise ValueError(f"Negative values cannot be encoded as unsigned offsets: {value}") + if value <= 0xFF: + return 1 + if value <= 0xFFFF: + return 2 + if value <= 0xFFFFFF: + return 3 + if value <= _MAX_UINT32: + return 4 + raise ValueError(f"Value is too large for a four-byte unsigned field: {value}") diff --git a/tests/table/test_encryption_metadata.py b/tests/table/test_encryption_metadata.py new file mode 100644 index 0000000000..026be0af85 --- /dev/null +++ b/tests/table/test_encryption_metadata.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +import pytest +from pydantic import ValidationError + +from pyiceberg.table.encryption import EncryptedKey +from pyiceberg.table.metadata import TableMetadataV3 +from pyiceberg.table.snapshots import Snapshot + + +@pytest.mark.parametrize( + "payload", + [ + {}, + {"key-id": "key-a"}, + {"encrypted-key-metadata": "ZW5jcnlwdGVkLW1ldGFkYXRh"}, + ], +) +def test_encrypted_key_requires_key_id_and_metadata(payload: dict[str, Any]) -> None: + with pytest.raises(ValidationError): + EncryptedKey.model_validate(payload) + + +def test_encrypted_key_deserialization_with_all_fields() -> None: + encrypted_key = EncryptedKey.model_validate( + { + "key-id": "key-a", + "encrypted-key-metadata": "ZW5jcnlwdGVkLW1ldGFkYXRh", + "encrypted-by-id": "root-key", + "properties": {"kms": "test", "purpose": "table"}, + } + ) + + assert encrypted_key.key_id == "key-a" + assert encrypted_key.encrypted_key_metadata == "ZW5jcnlwdGVkLW1ldGFkYXRh" + assert encrypted_key.encrypted_by_id == "root-key" + assert encrypted_key.properties == {"kms": "test", "purpose": "table"} + + +def test_encrypted_key_deserialization_with_required_fields_only() -> None: + encrypted_key = EncryptedKey.model_validate( + { + "key-id": "key-a", + "encrypted-key-metadata": "ZW5jcnlwdGVkLW1ldGFkYXRh", + } + ) + + assert encrypted_key.key_id == "key-a" + assert encrypted_key.encrypted_key_metadata == "ZW5jcnlwdGVkLW1ldGFkYXRh" + assert encrypted_key.encrypted_by_id is None + assert encrypted_key.properties is None + + +def test_encrypted_key_serialization_round_trip_uses_aliases() -> None: + encrypted_key = EncryptedKey.model_validate( + { + "key-id": "key-a", + "encrypted-key-metadata": "ZW5jcnlwdGVkLW1ldGFkYXRh", + "encrypted-by-id": "root-key", + "properties": {"kms": "test"}, + } + ) + + serialized = encrypted_key.model_dump_json(by_alias=True) + + assert EncryptedKey.model_validate_json(serialized) == encrypted_key + assert '"key-id"' in serialized + assert '"encrypted-key-metadata"' in serialized + assert "key_id" not in serialized + assert "encrypted_key_metadata" not in serialized + + +def test_encrypted_key_ignores_unknown_fields() -> None: + encrypted_key = EncryptedKey.model_validate( + { + "key-id": "key-a", + "encrypted-key-metadata": "ZW5jcnlwdGVkLW1ldGFkYXRh", + "some-future-field": "ignored", + } + ) + + assert encrypted_key.key_id == "key-a" + assert not hasattr(encrypted_key, "some_future_field") + + +def test_snapshot_key_id_deserialization_and_serialization() -> None: + snapshot = Snapshot.model_validate( + { + "snapshot-id": 25, + "timestamp-ms": 1602638573590, + "manifest-list": "s3:/a/b/c.avro", + "key-id": "manifest-list-key", + } + ) + snapshot_without_key = Snapshot.model_validate( + { + "snapshot-id": 26, + "timestamp-ms": 1602638573591, + "manifest-list": "s3:/a/b/d.avro", + } + ) + + assert snapshot.key_id == "manifest-list-key" + assert snapshot_without_key.key_id is None + assert '"key-id":"manifest-list-key"' in snapshot.model_dump_json(by_alias=True) + + +def test_table_metadata_v3_encryption_keys_deserialization(example_table_metadata_v3: dict[str, Any]) -> None: + metadata_dict = deepcopy(example_table_metadata_v3) + metadata_dict["encryption-keys"] = [ + { + "key-id": "key-a", + "encrypted-key-metadata": "ZW5jcnlwdGVkLW1ldGFkYXRh", + } + ] + + metadata = TableMetadataV3(**metadata_dict) + + assert metadata.encryption_keys is not None + assert metadata.encryption_keys[0].key_id == "key-a" + assert metadata.encryption_keys[0].encrypted_key_metadata == "ZW5jcnlwdGVkLW1ldGFkYXRh" + + +def test_table_metadata_v3_without_encryption_keys_deserialization(example_table_metadata_v3: dict[str, Any]) -> None: + metadata = TableMetadataV3(**deepcopy(example_table_metadata_v3)) + + assert metadata.encryption_keys is None diff --git a/tests/utils/test_variant.py b/tests/utils/test_variant.py new file mode 100644 index 0000000000..3343c34171 --- /dev/null +++ b/tests/utils/test_variant.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import struct +from datetime import date, datetime, timezone +from decimal import Decimal +from typing import Any + +import pytest + +from pyiceberg.types import IcebergType, NestedField, VariantType +from pyiceberg.utils.variant import _PRIMITIVE_FLOAT, decode_variant, encode_variant + + +@pytest.mark.parametrize( + "value", + [ + None, + True, + False, + 0, + 1, + -1, + 127, + 128, + -129, + 32767, + 70000, + 5_000_000_000, + 3.14, + "", + "hi", + "x" * 70, + {"a": 1, "b": "x", "c": True}, + {"outer": {"inner": 42}, "list": [1, 2, 3]}, + {}, + [1, "two", None, {"k": 3.0}], + [], + Decimal("0"), + Decimal("1.50"), + Decimal("-3.14159"), + Decimal("123456789012345.678"), + Decimal("-9." + "9" * 37), + b"", + b"\x00\x01\x02bytes", + date(2024, 1, 15), + date(1969, 12, 31), + datetime(2024, 1, 15, 12, 30, 45, 123456), + datetime(2024, 1, 15, 12, 30, 45, 123456, tzinfo=timezone.utc), + ], +) +def test_variant_round_trip(value: Any) -> None: + metadata_bytes, value_bytes = encode_variant(value) + + assert metadata_bytes[0] & 0x0F == 1 + decoded = decode_variant(metadata_bytes, value_bytes) + assert decoded == value + if isinstance(value, Decimal): + # Scale must survive the round trip, not just numeric equality. + assert decoded.as_tuple() == value.as_tuple() + + +def test_short_string_uses_short_string_basic_type() -> None: + _, value_bytes = encode_variant("hi") + + assert value_bytes[0] & 0x03 == 1 + + +def test_long_string_uses_primitive_string() -> None: + _, value_bytes = encode_variant("x" * 70) + + assert value_bytes[0] & 0x03 == 0 + assert value_bytes[0] >> 2 == 16 + + +def test_object_field_ids_are_emitted_in_lexicographic_field_name_order() -> None: + metadata_bytes, value_bytes = encode_variant({"c": 3, "a": 1, "b": 2}) + dictionary = _decode_metadata_dictionary(metadata_bytes) + + metadata = value_bytes[0] + assert metadata & 0x03 == 2 + is_large = bool(metadata & 0x40) + assert not is_large + field_id_size = ((metadata >> 4) & 0x03) + 1 + num_elements = value_bytes[1] + field_ids_start = 2 + field_ids = [ + int.from_bytes(value_bytes[offset : offset + field_id_size], "little") + for offset in range(field_ids_start, field_ids_start + num_elements * field_id_size, field_id_size) + ] + + assert [dictionary[field_id] for field_id in field_ids] == ["a", "b", "c"] + + +@pytest.mark.parametrize( + "value, expected_value_bytes", + [ + # null: basic_type=0 (primitive), value_header=0 -> 0x00 + (None, b"\x00"), + # true: primitive type 1 -> (1 << 2) | 0 = 0x04 + (True, b"\x04"), + # false: primitive type 2 -> (2 << 2) | 0 = 0x08 + (False, b"\x08"), + # int8 1: header (3 << 2) | 0 = 0x0c, then 0x01 + (1, b"\x0c\x01"), + # int8 -1: 0x0c then 0xff (two's complement) + (-1, b"\x0c\xff"), + # int16 200: header (4 << 2) = 0x10, then 0xc8 0x00 + (200, b"\x10\xc8\x00"), + # short string "hi": header (2 << 2) | 1 = 0x09, then bytes + ("hi", b"\x09hi"), + # empty short string: header (0 << 2) | 1 = 0x01 + ("", b"\x01"), + # double 1.0: header (7 << 2) = 0x1c, then IEEE-754 LE + (1.0, b"\x1c\x00\x00\x00\x00\x00\x00\xf0\x3f"), + # decimal4 1.50: header (8 << 2) = 0x20, scale=2, unscaled=150 (LE int32) + (Decimal("1.50"), b"\x20\x02\x96\x00\x00\x00"), + # date epoch+0 (1970-01-01): header (11 << 2) = 0x2c, then 0 (LE int32) + (date(1970, 1, 1), b"\x2c\x00\x00\x00\x00"), + # binary b"ab": header (15 << 2) = 0x3c, length 2 (LE uint32), then bytes + (b"ab", b"\x3c\x02\x00\x00\x00ab"), + ], +) +def test_variant_value_bytes_match_spec(value: Any, expected_value_bytes: bytes) -> None: + _, value_bytes = encode_variant(value) + + assert value_bytes == expected_value_bytes + assert decode_variant(b"\x11\x00\x00", value_bytes) == value + + +def test_scalar_metadata_header_is_empty_dictionary() -> None: + # version=1, sorted_strings=1, offset_size=1 -> 0x11; dict size 0; single offset 0. + metadata_bytes, _ = encode_variant(42) + + assert metadata_bytes == b"\x11\x00\x00" + + +def test_metadata_header_offset_size_uses_top_two_bits() -> None: + # Force a dictionary large enough (>255 bytes) to require a 2-byte offset size. + value = {f"field_name_{index:03d}": index for index in range(30)} + + metadata_bytes, _ = encode_variant(value) + + header = metadata_bytes[0] + assert header & 0x0F == 1 # version + assert (header >> 4) & 0x01 == 1 # sorted_strings + assert (header >> 5) & 0x01 == 0 # reserved bit must stay 0 + assert ((header >> 6) & 0x03) + 1 == 2 # offset_size lives in bits 6-7 + assert decode_variant(metadata_bytes, encode_variant(value)[1]) == value + + +def test_large_dictionary_metadata_round_trips() -> None: + # A dictionary whose byte length exceeds 0xFFFF forces offset_size 3, which sets + # bit 7 of the metadata header. Decoding must not mistake that for the reserved bit. + value = {f"field_{index:020d}": index for index in range(4000)} + + metadata_bytes, value_bytes = encode_variant(value) + + assert ((metadata_bytes[0] >> 6) & 0x03) + 1 >= 3 + assert decode_variant(metadata_bytes, value_bytes) == value + + +def test_decode_float_primitive() -> None: + # Variant float (primitive id 14) is read for interop, even though encoding emits + # double for native Python floats. header = (14 << 2) = 0x38, then IEEE-754 LE float. + value_bytes = bytes([_PRIMITIVE_FLOAT << 2]) + struct.pack(" None: + pytest.importorskip("pyarrow") + from pyiceberg.io.pyarrow import schema_to_pyarrow + from pyiceberg.schema import Schema + + schema = Schema(NestedField(1, "payload", VariantType(), required=False)) + + # Arrow-native / shredded Variant is intentionally unsupported; it must error, not + # silently emit a wrong Arrow type. + with pytest.raises(ValueError, match="Type not recognized: variant"): + schema_to_pyarrow(schema) + + +def test_variant_type() -> None: + field = NestedField(1, "payload", VariantType(), required=False) + + assert str(VariantType()) == "variant" + assert VariantType().minimum_format_version() == 3 + assert VariantType().model_dump_json() == '"variant"' + assert VariantType.model_validate_json('"variant"') == VariantType() + assert IcebergType.model_validate("variant") == VariantType() + assert NestedField.model_validate_json(field.model_dump_json()) == field + + +def _decode_metadata_dictionary(metadata_bytes: bytes) -> list[str]: + header = metadata_bytes[0] + offset_size = ((header >> 6) & 0x03) + 1 + offset = 1 + dictionary_size = int.from_bytes(metadata_bytes[offset : offset + offset_size], "little") + offset += offset_size + + offsets = [] + for _ in range(dictionary_size + 1): + offsets.append(int.from_bytes(metadata_bytes[offset : offset + offset_size], "little")) + offset += offset_size + + dictionary_bytes = metadata_bytes[offset:] + return [dictionary_bytes[start:end].decode("utf-8") for start, end in zip(offsets, offsets[1:], strict=False)]