From e7cbe87dc1f74879c410e1d95daecb4fa57bf158 Mon Sep 17 00:00:00 2001 From: Matt Katz Date: Thu, 25 Jun 2026 15:04:44 -0700 Subject: [PATCH 1/2] feat[datafusion]: push down list_length expression Signed-off-by: Matt Katz --- Cargo.lock | 1 + Cargo.toml | 1 + vortex-datafusion/Cargo.toml | 1 + vortex-datafusion/src/convert/exprs.rs | 181 +++++++++++++++++++++- vortex-datafusion/src/persistent/tests.rs | 111 +++++++++++++ 5 files changed, 293 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 66457601e06..02dc8d43f25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9831,6 +9831,7 @@ dependencies = [ "datafusion-execution 54.0.0", "datafusion-expr 54.0.0", "datafusion-functions 54.0.0", + "datafusion-functions-nested 54.0.0", "datafusion-physical-expr 54.0.0", "datafusion-physical-expr-adapter 54.0.0", "datafusion-physical-expr-common 54.0.0", diff --git a/Cargo.toml b/Cargo.toml index deed8b8d58a..876a0906e17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,6 +144,7 @@ datafusion-datasource = { version = "54", default-features = false } datafusion-execution = { version = "54" } datafusion-expr = { version = "54" } datafusion-functions = { version = "54" } +datafusion-functions-nested = { version = "54" } datafusion-physical-expr = { version = "54" } datafusion-physical-expr-adapter = { version = "54" } datafusion-physical-expr-common = { version = "54" } diff --git a/vortex-datafusion/Cargo.toml b/vortex-datafusion/Cargo.toml index 4aaefec35ff..72ebb3362b4 100644 --- a/vortex-datafusion/Cargo.toml +++ b/vortex-datafusion/Cargo.toml @@ -24,6 +24,7 @@ datafusion-datasource = { workspace = true, default-features = false } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-functions-nested = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index 7b02c6f531a..69e329dfcef 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -7,12 +7,14 @@ use arrow_schema::DataType; use arrow_schema::Field; use arrow_schema::Schema; use datafusion_common::Result as DFResult; +use datafusion_common::ScalarValue; use datafusion_common::exec_datafusion_err; use datafusion_common::tree_node::TreeNode; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_expr::Operator as DFOperator; use datafusion_functions::core::getfield::GetFieldFunc; use datafusion_functions::string::octet_length::OctetLengthFunc; +use datafusion_functions_nested::length::ArrayLength; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::projection::ProjectionExpr; @@ -32,6 +34,7 @@ use vortex::expr::get_item; use vortex::expr::is_not_null; use vortex::expr::is_null; use vortex::expr::list_contains; +use vortex::expr::list_length; use vortex::expr::lit; use vortex::expr::nested_case_when; use vortex::expr::not; @@ -155,6 +158,36 @@ impl DefaultExpressionConvertor { Ok(cast(byte_length(input), return_dtype)) } + /// Attempts to convert DataFusion's `array_length` function (aliased as `list_length`) to + /// Vortex `list_length`. + /// + /// Supports the single-argument form `array_length(arr)` and the equivalent two-argument + /// form with an explicit first dimension `array_length(arr, 1)`. Higher dimensions recurse + /// into nested lists and are rejected by [`can_array_length_be_pushed_down`] before reaching + /// this point. + fn try_convert_array_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult { + let Some(input) = array_length_input(scalar_fn) else { + return Err(exec_datafusion_err!( + "array_length pushdown supports only the one-argument form or an explicit first \ + dimension" + )); + }; + + let input = self.convert(input.as_ref())?; + // Both DataFusion `array_length` and Vortex `list_length` return UInt64; the cast aligns + // nullability with DataFusion's declared return type. + let return_dtype = self + .session + .arrow() + .from_arrow_field(&Field::new( + "", + scalar_fn.return_type().clone(), + scalar_fn.nullable(), + )) + .map_err(|e| exec_datafusion_err!("Failed to convert return type to dtype: {e}"))?; + Ok(cast(list_length(input), return_dtype)) + } + /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression. fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult { if let Some(octet_length_fn) = @@ -163,6 +196,12 @@ impl DefaultExpressionConvertor { return self.try_convert_octet_length(octet_length_fn); } + if let Some(array_length_fn) = + ScalarFunctionExpr::try_downcast_func::(scalar_fn) + { + return self.try_convert_array_length(array_length_fn); + } + if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::(scalar_fn) { // DataFusion's GetFieldFunc flattens nested field access into a single call @@ -511,6 +550,7 @@ fn is_convertible_expr(expr: &Arc) -> bool { || expr.downcast_ref::().is_some_and(|sf| { ScalarFunctionExpr::try_downcast_func::(sf).is_some() || ScalarFunctionExpr::try_downcast_func::(sf).is_some() + || ScalarFunctionExpr::try_downcast_func::(sf).is_some() }) } @@ -572,14 +612,20 @@ fn supported_data_types(dt: &DataType) -> bool { } /// Checks if a scalar function can be pushed down. -/// Currently GetFieldFunc and OctetLengthFunc are supported. +/// Currently GetFieldFunc, OctetLengthFunc, and ArrayLength are supported. fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool { if ScalarFunctionExpr::try_downcast_func::(scalar_fn).is_some() { return true; } - ScalarFunctionExpr::try_downcast_func::(scalar_fn) + if ScalarFunctionExpr::try_downcast_func::(scalar_fn) .is_some_and(|octet_length| can_octet_length_be_pushed_down(octet_length, schema)) + { + return true; + } + + ScalarFunctionExpr::try_downcast_func::(scalar_fn) + .is_some_and(|array_length| can_array_length_be_pushed_down(array_length, schema)) } fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool { @@ -598,6 +644,59 @@ fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Sche }) && can_be_pushed_down_impl(input, schema) } +fn can_array_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool { + let Some(input) = array_length_input(scalar_fn) else { + return false; + }; + + // The argument must resolve to a list type. We gate on the resolved data type rather than + // `can_be_pushed_down_impl`, since list columns are intentionally rejected there. We still + // require the argument to be a convertible expression (e.g. a column or struct field access). + input.data_type(schema).as_ref().is_ok_and(|data_type| { + matches!( + data_type, + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) + ) + }) && is_convertible_expr(input) +} + +/// Returns the list argument of an `array_length` call if the call is a form we can rewrite to +/// `list_length`: either the single-argument form `array_length(arr)`, or the two-argument form +/// with an explicit first dimension `array_length(arr, 1)`, which is equivalent. Higher +/// dimensions recurse into nested lists and are not supported. +fn array_length_input(scalar_fn: &ScalarFunctionExpr) -> Option<&Arc> { + match scalar_fn.args() { + [input] => Some(input), + [input, dimension] if is_dimension_one(dimension) => Some(input), + _ => None, + } +} + +/// Returns true if `expr` is an integer literal equal to 1. The dimension argument of +/// `array_length` is coerced to `Int64`, but we accept any integer width defensively. +fn is_dimension_one(expr: &Arc) -> bool { + let Some(literal) = expr.downcast_ref::() else { + return false; + }; + + let dimension = match literal.value() { + ScalarValue::Int8(Some(v)) => i64::from(*v), + ScalarValue::Int16(Some(v)) => i64::from(*v), + ScalarValue::Int32(Some(v)) => i64::from(*v), + ScalarValue::Int64(Some(v)) => *v, + ScalarValue::UInt8(Some(v)) => i64::from(*v), + ScalarValue::UInt16(Some(v)) => i64::from(*v), + ScalarValue::UInt32(Some(v)) => i64::from(*v), + ScalarValue::UInt64(Some(v)) => match i64::try_from(*v) { + Ok(v) => v, + Err(_) => return false, + }, + _ => return false, + }; + + dimension == 1 +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -652,6 +751,21 @@ mod tests { ) } + fn array_length_expr( + args: Vec>, + schema: &Schema, + ) -> Arc { + Arc::new( + ScalarFunctionExpr::try_new( + Arc::new(ScalarUDF::from(ArrayLength::new())), + args, + schema, + Arc::new(ConfigOptions::new()), + ) + .unwrap(), + ) + } + #[test] fn test_make_vortex_predicate_empty() { let expr_convertor = DefaultExpressionConvertor::default(); @@ -798,6 +912,23 @@ mod tests { "); } + #[rstest] + fn test_expr_from_df_array_length(test_schema: Schema) { + let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let array_length = array_length_expr(vec![expr], &test_schema); + + let result = DefaultExpressionConvertor::default() + .convert(array_length.as_ref()) + .unwrap(); + + assert_snapshot!(result.display_tree().to_string(), @r" + vortex.cast(u64?) + └── input: vortex.list.length() + └── input: vortex.get_item(unsupported_list) + └── input: vortex.root() + "); + } + #[rstest] // Supported types #[case::null(DataType::Null, true)] @@ -974,6 +1105,52 @@ mod tests { assert!(!can_be_pushed_down_impl(&octet_length, &test_schema)); } + #[rstest] + fn test_can_be_pushed_down_array_length_supported(test_schema: Schema) { + let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let array_length = array_length_expr(vec![expr], &test_schema); + + assert!(can_be_pushed_down_impl(&array_length, &test_schema)); + } + + #[rstest] + fn test_can_be_pushed_down_array_length_unsupported_operand(test_schema: Schema) { + // `array_length` over a non-list column cannot be pushed down. + let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc; + let array_length = Arc::new(ScalarFunctionExpr::new( + "array_length", + Arc::new(ScalarUDF::from(ArrayLength::new())), + vec![expr], + Arc::new(Field::new("array_length", DataType::UInt64, true)), + Arc::new(ConfigOptions::new()), + )) as Arc; + + assert!(!can_be_pushed_down_impl(&array_length, &test_schema)); + } + + #[rstest] + fn test_can_be_pushed_down_array_length_dimension_one_supported(test_schema: Schema) { + // `array_length(arr, 1)` is the first-dimension length, equivalent to `list_length`. + let list = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let dimension = + Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(1)))) as Arc; + let array_length = array_length_expr(vec![list, dimension], &test_schema); + + assert!(can_be_pushed_down_impl(&array_length, &test_schema)); + } + + #[rstest] + fn test_can_be_pushed_down_array_length_higher_dimension_not_supported(test_schema: Schema) { + // Dimensions other than 1 recurse into nested lists, which `list_length` does not model, + // so they must not be pushed down. + let list = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let dimension = + Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(2)))) as Arc; + let array_length = array_length_expr(vec![list, dimension], &test_schema); + + assert!(!can_be_pushed_down_impl(&array_length, &test_schema)); + } + // https://github.com/vortex-data/vortex/issues/6211 #[tokio::test] async fn test_cast_int_to_string() -> anyhow::Result<()> { diff --git a/vortex-datafusion/src/persistent/tests.rs b/vortex-datafusion/src/persistent/tests.rs index 220a1477b13..65660df8e21 100644 --- a/vortex-datafusion/src/persistent/tests.rs +++ b/vortex-datafusion/src/persistent/tests.rs @@ -13,6 +13,8 @@ use datafusion::execution::SessionStateBuilder; use datafusion::prelude::SessionConfig; use datafusion::prelude::SessionContext; use datafusion_common::GetExt; +use datafusion_expr::ScalarUDF; +use datafusion_functions_nested::length::ArrayLength; use datafusion_physical_plan::display::DisplayableExecutionPlan; use insta::assert_snapshot; use object_store::ObjectStore; @@ -21,6 +23,7 @@ use rstest::rstest; use vortex::VortexSessionDefault; use vortex::array::IntoArray; use vortex::array::arrays::ChunkedArray; +use vortex::array::arrays::ListArray; use vortex::array::arrays::StructArray; use vortex::array::arrays::VarBinArray; use vortex::array::validity::Validity; @@ -233,6 +236,114 @@ async fn test_octet_length_pushdown() -> anyhow::Result<()> { Ok(()) } +#[tokio::test] +async fn test_array_length_pushdown() -> anyhow::Result<()> { + // `new(true)` enables projection pushdown so the `array_length` projection is pushed into + // the Vortex scan rather than evaluated above it. + let ctx = TestSessionContext::new(true); + // `array_length` is a nested-array function; the test session is built without the + // `nested_expressions` default feature, so register it explicitly. + ctx.session + .register_udf(ScalarUDF::from(ArrayLength::new())); + let session = VortexSession::default(); + + // Five lists with element counts 3, 4, 0, 5, 2 respectively. The empty list exercises the + // 0 (not NULL) result that both DataFusion's `array_length` and Vortex's `list_length` + // produce for a non-null empty list. + let elements = buffer![ + 10i32, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140 + ] + .into_array(); + let offsets = buffer![0i32, 3, 7, 7, 12, 14].into_array(); + let int_list = ListArray::try_new(elements, offsets, Validity::AllValid)?.into_array(); + let ids = buffer![0i32, 1, 2, 3, 4].into_array(); + + let st = StructArray::try_new( + ["id", "int_list"].into(), + vec![ids, int_list], + 5, + Validity::NonNullable, + )?; + + let mut writer = ObjectStoreWrite::new(Arc::clone(&ctx.store), &"list.vortex".into()).await?; + session + .write_options() + .write(&mut writer, st.into_array().to_array_stream()) + .await?; + writer.shutdown().await?; + + // Projection: `array_length` computed from the list offsets without materializing elements. + let result = ctx + .session + .sql( + "SELECT id, array_length(int_list) AS len \ + FROM '/list.vortex' \ + ORDER BY id", + ) + .await? + .collect() + .await?; + + assert_eq!( + result[0].schema().field_with_name("len")?.data_type(), + &DataType::UInt64 + ); + assert_snapshot!(pretty_format_batches(&result)?, @r" + +----+-----+ + | id | len | + +----+-----+ + | 0 | 3 | + | 1 | 4 | + | 2 | 0 | + | 3 | 5 | + | 4 | 2 | + +----+-----+ + "); + + // The explicit first-dimension form `array_length(int_list, 1)` is equivalent and pushes + // down identically. + let with_dimension = ctx + .session + .sql( + "SELECT array_length(int_list, 1) AS len \ + FROM '/list.vortex' \ + ORDER BY id", + ) + .await? + .collect() + .await?; + + assert_snapshot!(pretty_format_batches(&with_dimension)?, @r" + +-----+ + | len | + +-----+ + | 3 | + | 4 | + | 0 | + | 5 | + | 2 | + +-----+ + "); + + // Filter: `WHERE array_length(int_list) >= 4` keeps the 4- and 5-element lists. + let filtered = ctx + .session + .sql("SELECT COUNT(*) AS cnt FROM '/list.vortex' WHERE array_length(int_list) >= 4") + .await? + .collect() + .await?; + + assert_snapshot!(pretty_format_batches(&filtered)?, @r" + +-----+ + | cnt | + +-----+ + | 2 | + +-----+ + "); + + Ok(()) +} + #[tokio::test] async fn create_table_ordered_by() -> anyhow::Result<()> { let ctx = TestSessionContext::default(); From 20830907b3b67c920113f703468016d1a10c5652 Mon Sep 17 00:00:00 2001 From: Matt Katz Date: Thu, 25 Jun 2026 17:31:02 -0700 Subject: [PATCH 2/2] comments and test change Signed-off-by: Matt Katz --- vortex-datafusion/src/convert/exprs.rs | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index 69e329dfcef..8f822933898 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -162,9 +162,7 @@ impl DefaultExpressionConvertor { /// Vortex `list_length`. /// /// Supports the single-argument form `array_length(arr)` and the equivalent two-argument - /// form with an explicit first dimension `array_length(arr, 1)`. Higher dimensions recurse - /// into nested lists and are rejected by [`can_array_length_be_pushed_down`] before reaching - /// this point. + /// form with an explicit first dimension `array_length(arr, 1)`. fn try_convert_array_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult { let Some(input) = array_length_input(scalar_fn) else { return Err(exec_datafusion_err!( @@ -174,8 +172,6 @@ impl DefaultExpressionConvertor { }; let input = self.convert(input.as_ref())?; - // Both DataFusion `array_length` and Vortex `list_length` return UInt64; the cast aligns - // nullability with DataFusion's declared return type. let return_dtype = self .session .arrow() @@ -732,7 +728,7 @@ mod tests { true, ), Field::new( - "unsupported_list", + "tags", DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), true, ), @@ -914,7 +910,7 @@ mod tests { #[rstest] fn test_expr_from_df_array_length(test_schema: Schema) { - let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let array_length = array_length_expr(vec![expr], &test_schema); let result = DefaultExpressionConvertor::default() @@ -924,7 +920,7 @@ mod tests { assert_snapshot!(result.display_tree().to_string(), @r" vortex.cast(u64?) └── input: vortex.list.length() - └── input: vortex.get_item(unsupported_list) + └── input: vortex.get_item(tags) └── input: vortex.root() "); } @@ -993,7 +989,7 @@ mod tests { #[rstest] fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) { let col_expr = - Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + Arc::new(df_expr::Column::new("tags", 5)) as Arc; assert!(!can_be_pushed_down_impl(&col_expr, &test_schema)); } @@ -1050,7 +1046,7 @@ mod tests { #[rstest] fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) { - let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let left = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let right = Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc; let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right)) @@ -1073,7 +1069,7 @@ mod tests { #[rstest] fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) { - let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some( "test%".to_string(), )))) as Arc; @@ -1093,7 +1089,7 @@ mod tests { #[rstest] fn test_can_be_pushed_down_octet_length_unsupported_operand(test_schema: Schema) { - let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let octet_length = Arc::new(ScalarFunctionExpr::new( "octet_length", Arc::new(ScalarUDF::from(OctetLengthFunc::new())), @@ -1107,7 +1103,7 @@ mod tests { #[rstest] fn test_can_be_pushed_down_array_length_supported(test_schema: Schema) { - let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let array_length = array_length_expr(vec![expr], &test_schema); assert!(can_be_pushed_down_impl(&array_length, &test_schema)); @@ -1131,7 +1127,7 @@ mod tests { #[rstest] fn test_can_be_pushed_down_array_length_dimension_one_supported(test_schema: Schema) { // `array_length(arr, 1)` is the first-dimension length, equivalent to `list_length`. - let list = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let dimension = Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(1)))) as Arc; let array_length = array_length_expr(vec![list, dimension], &test_schema); @@ -1143,7 +1139,7 @@ mod tests { fn test_can_be_pushed_down_array_length_higher_dimension_not_supported(test_schema: Schema) { // Dimensions other than 1 recurse into nested lists, which `list_length` does not model, // so they must not be pushed down. - let list = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc; + let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc; let dimension = Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(2)))) as Arc; let array_length = array_length_expr(vec![list, dimension], &test_schema);