Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions mabel/data/writers/internals/blob_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
import json
import threading
from typing import Optional
from typing import List, Optional, Union

import orjson
import orso
Expand Down Expand Up @@ -32,13 +32,15 @@ def __init__(
format: str = "parquet",
schema: Optional[RelationSchema] = None,
parquet_row_group_size: int = 5000,
sort_by: Optional[str] = None,
sort_by: Optional[Union[str, List]] = None,
Copy link

Copilot AI Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint List is too generic. It should be List[Union[str, Tuple[str, str]]] to clearly indicate it accepts either a list of column names or a list of tuples with column name and sort direction.

Copilot uses AI. Check for mistakes.
use_dictionary: Optional[Union[bool, List[str]]] = None,
**kwargs,
):
self.format = format
self.maximum_blob_size = blob_size
self.parquet_row_group_size = parquet_row_group_size
self.sort_by = sort_by
self.use_dictionary = use_dictionary

if format not in SUPPORTED_FORMATS_ALGORITHMS:
raise ValueError(
Expand Down Expand Up @@ -166,12 +168,18 @@ def commit(self):

# sort the table if sort_by is specified
if self.sort_by:
pytable = pytable.sort_by(self.sort_by)
# Convert list of strings to PyArrow format
sort_spec = self.sort_by
if isinstance(self.sort_by, list) and all(isinstance(item, str) for item in self.sort_by):
Copy link

Copilot AI Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validation logic should be extracted to a separate method or moved to the constructor for earlier validation and better error handling. Currently, invalid sort_by values would only be caught during commit().

Copilot uses AI. Check for mistakes.
# Convert list of strings to list of tuples with default ascending order
sort_spec = [(col, "ascending") for col in self.sort_by]
pytable = pytable.sort_by(sort_spec)

tempfile = io.BytesIO()
pyarrow.parquet.write_table(
pytable, where=tempfile, row_group_size=self.parquet_row_group_size
)
write_kwargs = {"row_group_size": self.parquet_row_group_size}
if self.use_dictionary is not None:
write_kwargs["use_dictionary"] = self.use_dictionary
pyarrow.parquet.write_table(pytable, where=tempfile, **write_kwargs)

tempfile.seek(0)
write_buffer = tempfile.read()
Expand Down
2 changes: 1 addition & 1 deletion mabel/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Store the version here so:
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
__version__ = "0.6.28"
__version__ = "0.6.29"

# nodoc - don't add to the documentation wiki
188 changes: 188 additions & 0 deletions tests/test_writer_parquet_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,194 @@ def test_parquet_default_row_group_size():
shutil.rmtree("_temp_default", ignore_errors=True)


def test_parquet_sorting_list_of_strings():
    """sort_by given as a list of column names sorts on each column in order."""
    shutil.rmtree("_temp_sort_list", ignore_errors=True)

    # Writer sorting on two columns: category is the primary key, id the tiebreak.
    writer = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_sort_list",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[
            {"name": "id", "type": "INTEGER"},
            {"name": "category", "type": "VARCHAR"},
            {"name": "value", "type": "VARCHAR"},
        ],
        sort_by=["category", "id"],  # category first, then id
        parquet_row_group_size=5000,
    )

    # Feed the rows deliberately out of order so any ordering must come from the sort.
    unsorted_rows = [
        {"id": 3, "category": "B", "value": "value_3"},
        {"id": 1, "category": "A", "value": "value_1"},
        {"id": 4, "category": "B", "value": "value_4"},
        {"id": 2, "category": "A", "value": "value_2"},
        {"id": 5, "category": "C", "value": "value_5"},
    ]
    for row in unsorted_rows:
        writer.append(row)
    writer.finalize()

    # Read everything back; it should now come out ordered by (category, id).
    records = list(Reader(inner_reader=DiskReader, dataset="_temp_sort_list"))
    assert len(records) == 5, f"Expected 5 records, got {len(records)}"

    expected_order = sorted(unsorted_rows, key=lambda r: (r["category"], r["id"]))
    for i, (record, expected) in enumerate(zip(records, expected_order)):
        assert record["id"] == expected["id"], f"Record {i} id mismatch: {record['id']} != {expected['id']}"
        assert record["category"] == expected["category"], f"Record {i} category mismatch"

    shutil.rmtree("_temp_sort_list", ignore_errors=True)


def test_parquet_sorting_single_column_list():
    """A one-element sort_by list behaves like sorting on that single column."""
    shutil.rmtree("_temp_sort_single_list", ignore_errors=True)

    writer = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_sort_single_list",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        sort_by=["id"],  # single column, supplied as a list rather than a string
        parquet_row_group_size=5000,
    )

    # Append ids 10 down to 1 so the input arrives in strictly descending order.
    for ident in reversed(range(1, 11)):
        writer.append({"id": ident, "value": f"value_{ident}"})
    writer.finalize()

    records = list(Reader(inner_reader=DiskReader, dataset="_temp_sort_single_list"))
    assert len(records) == 10, f"Expected 10 records, got {len(records)}"

    # After sorting the ids must read back ascending 1..10.
    ids = [row["id"] for row in records]
    assert ids == list(range(1, 11)), f"Records are not sorted correctly: {ids}"

    shutil.rmtree("_temp_sort_single_list", ignore_errors=True)


def test_parquet_dictionary_encoding_all():
    """use_dictionary=True (all columns) still round-trips every record."""
    shutil.rmtree("_temp_dict_all", ignore_errors=True)

    writer = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_dict_all",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[
            {"name": "id", "type": "INTEGER"},
            {"name": "category", "type": "VARCHAR"},
        ],
        use_dictionary=True,  # dictionary-encode every column
    )

    # Only five distinct category values across 100 rows — a good fit for
    # dictionary encoding.
    for n in range(100):
        writer.append({"id": n, "category": f"category_{n % 5}"})
    writer.finalize()

    # The encoding choice must not affect what is read back.
    stored = list(Reader(inner_reader=DiskReader, dataset="_temp_dict_all"))
    assert len(stored) == 100, f"Expected 100 records, got {len(stored)}"

    shutil.rmtree("_temp_dict_all", ignore_errors=True)


def test_parquet_dictionary_encoding_disabled():
    """use_dictionary=False turns encoding off without losing any data."""
    shutil.rmtree("_temp_dict_disabled", ignore_errors=True)

    writer = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_dict_disabled",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[
            {"name": "id", "type": "INTEGER"},
            {"name": "category", "type": "VARCHAR"},
        ],
        use_dictionary=False,  # plain encoding for every column
    )

    # Same repetitive payload as the enabled case, for symmetry.
    for n in range(100):
        writer.append({"id": n, "category": f"category_{n % 5}"})
    writer.finalize()

    # Disabling the encoding must not affect what is read back.
    stored = list(Reader(inner_reader=DiskReader, dataset="_temp_dict_disabled"))
    assert len(stored) == 100, f"Expected 100 records, got {len(stored)}"

    shutil.rmtree("_temp_dict_disabled", ignore_errors=True)


def test_parquet_dictionary_encoding_specific_columns():
    """use_dictionary as a column list encodes only the named columns; data round-trips."""
    shutil.rmtree("_temp_dict_specific", ignore_errors=True)

    writer = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_dict_specific",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[
            {"name": "id", "type": "INTEGER"},
            {"name": "category", "type": "VARCHAR"},
            {"name": "value", "type": "VARCHAR"},
        ],
        use_dictionary=["category"],  # dictionary-encode only 'category'
    )

    # category repeats (5 distinct values) while value is unique per row, so only
    # category benefits from the dictionary.
    for n in range(100):
        row = {"id": n, "category": f"category_{n % 5}", "value": f"unique_value_{n}"}
        writer.append(row)
    writer.finalize()

    stored = list(Reader(inner_reader=DiskReader, dataset="_temp_dict_specific"))
    assert len(stored) == 100, f"Expected 100 records, got {len(stored)}"

    # Rows 0 and 50 both fall in the first category bucket (n % 5 == 0).
    assert stored[0]["category"] == "category_0"
    assert stored[50]["category"] == "category_0"

    shutil.rmtree("_temp_dict_specific", ignore_errors=True)


if __name__ == "__main__": # pragma: no cover
from tests.helpers.runner import run_tests

Expand Down
Loading