diff --git a/mabel/data/writers/internals/blob_writer.py b/mabel/data/writers/internals/blob_writer.py index 43653bf..feb1d49 100644 --- a/mabel/data/writers/internals/blob_writer.py +++ b/mabel/data/writers/internals/blob_writer.py @@ -1,7 +1,7 @@ import io import json import threading -from typing import Optional +from typing import List, Optional, Union import orjson import orso @@ -32,13 +32,15 @@ def __init__( format: str = "parquet", schema: Optional[RelationSchema] = None, parquet_row_group_size: int = 5000, - sort_by: Optional[str] = None, + sort_by: Optional[Union[str, List]] = None, + use_dictionary: Optional[Union[bool, List[str]]] = None, **kwargs, ): self.format = format self.maximum_blob_size = blob_size self.parquet_row_group_size = parquet_row_group_size self.sort_by = sort_by + self.use_dictionary = use_dictionary if format not in SUPPORTED_FORMATS_ALGORITHMS: raise ValueError( @@ -166,12 +168,18 @@ def commit(self): # sort the table if sort_by is specified if self.sort_by: - pytable = pytable.sort_by(self.sort_by) + # Convert list of strings to PyArrow format + sort_spec = self.sort_by + if isinstance(self.sort_by, list) and all(isinstance(item, str) for item in self.sort_by): + # Convert list of strings to list of tuples with default ascending order + sort_spec = [(col, "ascending") for col in self.sort_by] + pytable = pytable.sort_by(sort_spec) tempfile = io.BytesIO() - pyarrow.parquet.write_table( - pytable, where=tempfile, row_group_size=self.parquet_row_group_size - ) + write_kwargs = {"row_group_size": self.parquet_row_group_size} + if self.use_dictionary is not None: + write_kwargs["use_dictionary"] = self.use_dictionary + pyarrow.parquet.write_table(pytable, where=tempfile, **write_kwargs) tempfile.seek(0) write_buffer = tempfile.read() diff --git a/mabel/version.py b/mabel/version.py index a1d0e70..9e24bee 100644 --- a/mabel/version.py +++ b/mabel/version.py @@ -1,6 +1,6 @@ # Store the version here so: # 1) we don't load dependencies by storing it in __init__.py # 2) we can import it in setup.py for the same reason -__version__ = "0.6.28" +__version__ = "0.6.29" # nodoc - don't add to the documentation wiki diff --git a/tests/test_writer_parquet_features.py b/tests/test_writer_parquet_features.py index 856e851..8b8cfd1 100644 --- a/tests/test_writer_parquet_features.py +++ b/tests/test_writer_parquet_features.py @@ -148,6 +148,194 @@ def test_parquet_default_row_group_size(): shutil.rmtree("_temp_default", ignore_errors=True) +def test_parquet_sorting_list_of_strings(): + """Test that sort_by parameter can accept a list of strings""" + shutil.rmtree("_temp_sort_list", ignore_errors=True) + + # Create a writer with sorting by list of strings + w = BatchWriter( + inner_writer=DiskWriter, + dataset="_temp_sort_list", + format="parquet", + date=datetime.datetime.utcnow().date(), + schema=[ + {"name": "id", "type": "INTEGER"}, + {"name": "category", "type": "VARCHAR"}, + {"name": "value", "type": "VARCHAR"} + ], + sort_by=["category", "id"], # Sort by category first, then id + parquet_row_group_size=5000, + ) + + # Write records in random order + records_to_write = [ + {"id": 3, "category": "B", "value": "value_3"}, + {"id": 1, "category": "A", "value": "value_1"}, + {"id": 4, "category": "B", "value": "value_4"}, + {"id": 2, "category": "A", "value": "value_2"}, + {"id": 5, "category": "C", "value": "value_5"}, + ] + + for record in records_to_write: + w.append(record) + + w.finalize() + + # Read back and verify the data is sorted by category, then id + r = Reader(inner_reader=DiskReader, dataset="_temp_sort_list") + records = list(r) + + assert len(records) == 5, f"Expected 5 records, got {len(records)}" + + # Check that records are sorted by category first, then by id + expected_order = [ + {"id": 1, "category": "A", "value": "value_1"}, + {"id": 2, "category": "A", "value": "value_2"}, + {"id": 3, "category": "B", "value": "value_3"}, + {"id": 4, "category": "B", "value": "value_4"}, + {"id": 5, "category": "C", "value": "value_5"}, + ] + + for i, record in enumerate(records): + assert record["id"] == expected_order[i]["id"], f"Record {i} id mismatch: {record['id']} != {expected_order[i]['id']}" + assert record["category"] == expected_order[i]["category"], f"Record {i} category mismatch" + + shutil.rmtree("_temp_sort_list", ignore_errors=True) + + +def test_parquet_sorting_single_column_list(): + """Test that sort_by parameter can accept a list with a single string""" + shutil.rmtree("_temp_sort_single_list", ignore_errors=True) + + # Create a writer with sorting by a list containing a single column + w = BatchWriter( + inner_writer=DiskWriter, + dataset="_temp_sort_single_list", + format="parquet", + date=datetime.datetime.utcnow().date(), + schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}], + sort_by=["id"], # Sort by id column as a list + parquet_row_group_size=5000, + ) + + # Write records in reverse order + for i in range(10, 0, -1): + w.append({"id": i, "value": f"value_{i}"}) + + w.finalize() + + # Read back and verify the data is sorted + r = Reader(inner_reader=DiskReader, dataset="_temp_sort_single_list") + records = list(r) + + assert len(records) == 10, f"Expected 10 records, got {len(records)}" + + # Check that records are sorted by id + ids = [record["id"] for record in records] + assert ids == list(range(1, 11)), f"Records are not sorted correctly: {ids}" + + shutil.rmtree("_temp_sort_single_list", ignore_errors=True) + + +def test_parquet_dictionary_encoding_all(): + """Test that use_dictionary parameter can be set to True for all columns""" + shutil.rmtree("_temp_dict_all", ignore_errors=True) + + w = BatchWriter( + inner_writer=DiskWriter, + dataset="_temp_dict_all", + format="parquet", + date=datetime.datetime.utcnow().date(), + schema=[ + {"name": "id", "type": "INTEGER"}, + {"name": "category", "type": "VARCHAR"} + ], + use_dictionary=True, # Enable dictionary encoding for all columns + ) + + # Write records with repeated category values (good for dictionary encoding) + for i in range(100): + w.append({"id": i, "category": f"category_{i % 5}"}) + + w.finalize() + + # Read back and verify data + r = Reader(inner_reader=DiskReader, dataset="_temp_dict_all") + records = list(r) + assert len(records) == 100, f"Expected 100 records, got {len(records)}" + + shutil.rmtree("_temp_dict_all", ignore_errors=True) + + +def test_parquet_dictionary_encoding_disabled(): + """Test that use_dictionary parameter can be set to False to disable dictionary encoding""" + shutil.rmtree("_temp_dict_disabled", ignore_errors=True) + + w = BatchWriter( + inner_writer=DiskWriter, + dataset="_temp_dict_disabled", + format="parquet", + date=datetime.datetime.utcnow().date(), + schema=[ + {"name": "id", "type": "INTEGER"}, + {"name": "category", "type": "VARCHAR"} + ], + use_dictionary=False, # Disable dictionary encoding + ) + + # Write records + for i in range(100): + w.append({"id": i, "category": f"category_{i % 5}"}) + + w.finalize() + + # Read back and verify data + r = Reader(inner_reader=DiskReader, dataset="_temp_dict_disabled") + records = list(r) + assert len(records) == 100, f"Expected 100 records, got {len(records)}" + + shutil.rmtree("_temp_dict_disabled", ignore_errors=True) + + +def test_parquet_dictionary_encoding_specific_columns(): + """Test that use_dictionary parameter can specify specific columns for dictionary encoding""" + shutil.rmtree("_temp_dict_specific", ignore_errors=True) + + w = BatchWriter( + inner_writer=DiskWriter, + dataset="_temp_dict_specific", + format="parquet", + date=datetime.datetime.utcnow().date(), + schema=[ + {"name": "id", "type": "INTEGER"}, + {"name": "category", "type": "VARCHAR"}, + {"name": "value", "type": "VARCHAR"} + ], + use_dictionary=["category"], # Only encode 'category' column with dictionary + ) + + # Write records with repeated category values + for i in range(100): + w.append({ + "id": i, + "category": f"category_{i % 5}", + "value": f"unique_value_{i}" + }) + + w.finalize() + + # Read back and verify data + r = Reader(inner_reader=DiskReader, dataset="_temp_dict_specific") + records = list(r) + assert len(records) == 100, f"Expected 100 records, got {len(records)}" + + # Verify the data is correct + assert records[0]["category"] == "category_0" + assert records[50]["category"] == "category_0" + + shutil.rmtree("_temp_dict_specific", ignore_errors=True) + + if __name__ == "__main__": # pragma: no cover from tests.helpers.runner import run_tests