Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 71 additions & 46 deletions parquet-variant-compute/src/shred_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,24 +493,26 @@ impl IntoShreddingField for (DataType, bool) {
/// use parquet_variant::{VariantPath, VariantPathElement};
/// use parquet_variant_compute::ShreddedSchemaBuilder;
///
/// // Define the shredding schema using the builder
/// let shredding_type = ShreddedSchemaBuilder::default()
/// fn main() -> Result<(), arrow::error::ArrowError> {
/// // Define the shredding schema using the builder
/// let shredding_type = ShreddedSchemaBuilder::default()
/// // store the "time" field as a separate UTC timestamp
/// .with_path("time", (&DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), true))
/// .with_path("time", (&DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), true))?
/// // store hostname as non-nullable Utf8
/// .with_path("hostname", (&DataType::Utf8, false))
/// .with_path("hostname", (&DataType::Utf8, false))?
/// // pass a FieldRef directly
/// .with_path(
/// "metadata.trace_id",
/// Arc::new(Field::new("trace_id", DataType::FixedSizeBinary(16), false)),
/// )
/// )?
/// // field name with a dot: use VariantPath to avoid splitting
/// .with_path(
/// VariantPath::from_iter([VariantPathElement::from("metrics.cpu")]),
/// &DataType::Float64,
/// )
/// )?
/// .build();
///
/// Ok(())
/// }
/// // The shredding_type can now be passed to shred_variant:
/// // let shredded = shred_variant(&input, &shredding_type)?;
/// ```
Expand All @@ -536,14 +538,17 @@ impl ShreddedSchemaBuilder {
/// * `path` - Anything convertible to [`VariantPath`] (e.g., a `&str`)
/// * `field` - Anything convertible via [`IntoShreddingField`] (e.g. `FieldRef`,
/// `&DataType`, or `(&DataType, bool)` to control nullability)
pub fn with_path<'a, P, F>(mut self, path: P, field: F) -> Self
pub fn with_path<'a, P, F>(mut self, path: P, field: F) -> Result<Self>
where
P: Into<VariantPath<'a>>,
P: TryInto<VariantPath<'a>>,
P::Error: std::fmt::Debug,
F: IntoShreddingField,
{
let path: VariantPath<'a> = path.into();
let path: VariantPath<'a> = path
.try_into()
.map_err(|e| ArrowError::InvalidArgumentError(format!("{:?}", e)))?;
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

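I used this `map_err` together with the `P::Error: std::fmt::Debug` bound on L544 because `with_path` also has to accept a `VariantPath` directly. The `VariantPath` -> `VariantPath` conversion has the error type `Infallible`, and there is no `From<Infallible> for ArrowError`, so `?` cannot convert the error on its own. `ArrowError` is defined in arrow-schema, and I'm not sure it's appropriate to add that impl there. If adding `From<Infallible> for ArrowError` is the better option, I can change it.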

self.root.insert_path(&path, field.into_shredding_field());
self
Ok(self)
}

/// Build the final [`DataType`].
Expand Down Expand Up @@ -1558,7 +1563,7 @@ mod tests {
}

#[test]
fn test_object_shredding_comprehensive() {
fn test_object_shredding_comprehensive() -> Result<()> {
let input = build_variant_array(vec![
// Row 0: Fully shredded object
VariantRow::Object(vec![
Expand Down Expand Up @@ -1596,8 +1601,8 @@ mod tests {
// Create target schema: struct<score: float64, age: int64>
// Both types are supported for shredding
let target_schema = ShreddedSchemaBuilder::default()
.with_path("score", &DataType::Float64)
.with_path("age", &DataType::Int64)
.with_path("score", &DataType::Float64)?
.with_path("age", &DataType::Int64)?
.build();

let result = shred_variant(&input, &target_schema).unwrap();
Expand Down Expand Up @@ -1903,6 +1908,7 @@ mod tests {
}),
}),
);
Ok(())
}

#[test]
Expand Down Expand Up @@ -1998,7 +2004,7 @@ mod tests {
}

#[test]
fn test_object_different_schemas() {
fn test_object_different_schemas() -> Result<()> {
// Create object with multiple fields
let input = build_variant_array(vec![VariantRow::Object(vec![
("id", VariantValue::from(123i32)),
Expand All @@ -2008,34 +2014,36 @@ mod tests {

// Test with schema containing only id field
let schema1 = ShreddedSchemaBuilder::default()
.with_path("id", &DataType::Int32)
.with_path("id", &DataType::Int32)?
.build();
let result1 = shred_variant(&input, &schema1).unwrap();
let value_field1 = result1.value_field().unwrap();
assert!(!value_field1.is_null(0)); // should contain {"age": 25, "score": 95.5}

// Test with schema containing id and age fields
let schema2 = ShreddedSchemaBuilder::default()
.with_path("id", &DataType::Int32)
.with_path("age", &DataType::Int64)
.with_path("id", &DataType::Int32)?
.with_path("age", &DataType::Int64)?
.build();
let result2 = shred_variant(&input, &schema2).unwrap();
let value_field2 = result2.value_field().unwrap();
assert!(!value_field2.is_null(0)); // should contain {"score": 95.5}

// Test with schema containing all fields
let schema3 = ShreddedSchemaBuilder::default()
.with_path("id", &DataType::Int32)
.with_path("age", &DataType::Int64)
.with_path("score", &DataType::Float64)
.with_path("id", &DataType::Int32)?
.with_path("age", &DataType::Int64)?
.with_path("score", &DataType::Float64)?
.build();
let result3 = shred_variant(&input, &schema3).unwrap();
let value_field3 = result3.value_field().unwrap();
assert!(value_field3.is_null(0)); // fully shredded, no remaining fields

Ok(())
}

#[test]
fn test_uuid_shredding_in_objects() {
fn test_uuid_shredding_in_objects() -> Result<()> {
let mock_uuid_1 = Uuid::new_v4();
let mock_uuid_2 = Uuid::new_v4();
let mock_uuid_3 = Uuid::new_v4();
Expand Down Expand Up @@ -2069,8 +2077,8 @@ mod tests {
]);

let target_schema = ShreddedSchemaBuilder::default()
.with_path("id", DataType::FixedSizeBinary(16))
.with_path("session_id", DataType::FixedSizeBinary(16))
.with_path("id", DataType::FixedSizeBinary(16))?
.with_path("session_id", DataType::FixedSizeBinary(16))?
.build();

let result = shred_variant(&input, &target_schema).unwrap();
Expand Down Expand Up @@ -2201,6 +2209,8 @@ mod tests {

// Row 5: Null
assert!(result.is_null(5));

Ok(())
}

#[test]
Expand Down Expand Up @@ -2251,10 +2261,10 @@ mod tests {
}

#[test]
fn test_variant_schema_builder_simple() {
fn test_variant_schema_builder_simple() -> Result<()> {
let shredding_type = ShreddedSchemaBuilder::default()
.with_path("a", &DataType::Int64)
.with_path("b", &DataType::Float64)
.with_path("a", &DataType::Int64)?
.with_path("b", &DataType::Float64)?
.build();

assert_eq!(
Expand All @@ -2264,14 +2274,16 @@ mod tests {
Field::new("b", DataType::Float64, true),
]))
);

Ok(())
}

#[test]
fn test_variant_schema_builder_nested() {
fn test_variant_schema_builder_nested() -> Result<()> {
let shredding_type = ShreddedSchemaBuilder::default()
.with_path("a", &DataType::Int64)
.with_path("b.c", &DataType::Utf8)
.with_path("b.d", &DataType::Float64)
.with_path("a", &DataType::Int64)?
.with_path("b.c", &DataType::Utf8)?
.with_path("b.d", &DataType::Float64)?
.build();

assert_eq!(
Expand All @@ -2288,13 +2300,15 @@ mod tests {
),
]))
);

Ok(())
}

#[test]
fn test_variant_schema_builder_with_path_variant_path_arg() {
fn test_variant_schema_builder_with_path_variant_path_arg() -> Result<()> {
let path = VariantPath::from_iter([VariantPathElement::from("a.b")]);
let shredding_type = ShreddedSchemaBuilder::default()
.with_path(path, &DataType::Int64)
.with_path(path, &DataType::Int64)?
.build();

match shredding_type {
Expand All @@ -2305,16 +2319,18 @@ mod tests {
}
_ => panic!("expected struct data type"),
}

Ok(())
}

#[test]
fn test_variant_schema_builder_custom_nullability() {
fn test_variant_schema_builder_custom_nullability() -> Result<()> {
let shredding_type = ShreddedSchemaBuilder::default()
.with_path(
"foo",
Arc::new(Field::new("should_be_renamed", DataType::Utf8, false)),
)
.with_path("bar", (&DataType::Int64, false))
)?
.with_path("bar", (&DataType::Int64, false))?
.build();

let DataType::Struct(fields) = shredding_type else {
Expand All @@ -2328,10 +2344,12 @@ mod tests {
let bar = fields.iter().find(|f| f.name() == "bar").unwrap();
assert_eq!(bar.data_type(), &DataType::Int64);
assert!(!bar.is_nullable());

Ok(())
}

#[test]
fn test_variant_schema_builder_with_shred_variant() {
fn test_variant_schema_builder_with_shred_variant() -> Result<()> {
let input = build_variant_array(vec![
VariantRow::Object(vec![
("time", VariantValue::from(1234567890i64)),
Expand All @@ -2346,8 +2364,8 @@ mod tests {
]);

let shredding_type = ShreddedSchemaBuilder::default()
.with_path("time", &DataType::Int64)
.with_path("hostname", &DataType::Utf8)
.with_path("time", &DataType::Int64)?
.with_path("hostname", &DataType::Utf8)?
.build();

let result = shred_variant(&input, &shredding_type).unwrap();
Expand Down Expand Up @@ -2424,13 +2442,15 @@ mod tests {

// Row 2
assert!(result.is_null(2));

Ok(())
}

#[test]
fn test_variant_schema_builder_conflicting_path() {
fn test_variant_schema_builder_conflicting_path() -> Result<()> {
let shredding_type = ShreddedSchemaBuilder::default()
.with_path("a", &DataType::Int64)
.with_path("a", &DataType::Float64)
.with_path("a", &DataType::Int64)?
.with_path("a", &DataType::Float64)?
.build();

assert_eq!(
Expand All @@ -2439,25 +2459,30 @@ mod tests {
vec![Field::new("a", DataType::Float64, true),]
))
);

Ok(())
}

#[test]
fn test_variant_schema_builder_root_path() {
fn test_variant_schema_builder_root_path() -> Result<()> {
let path = VariantPath::new(vec![]);
let shredding_type = ShreddedSchemaBuilder::default()
.with_path(path, &DataType::Int64)
.with_path(path, &DataType::Int64)?
.build();

assert_eq!(shredding_type, DataType::Int64);

Ok(())
}

#[test]
fn test_variant_schema_builder_empty_path() {
fn test_variant_schema_builder_empty_path() -> Result<()> {
let shredding_type = ShreddedSchemaBuilder::default()
.with_path("", &DataType::Int64)
.with_path("", &DataType::Int64)?
.build();

assert_eq!(shredding_type, DataType::Int64);
Ok(())
}

#[test]
Expand Down
Loading
Loading