Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ datafusion-datasource = { version = "54", default-features = false }
datafusion-execution = { version = "54" }
datafusion-expr = { version = "54" }
datafusion-functions = { version = "54" }
datafusion-functions-nested = { version = "54" }
datafusion-physical-expr = { version = "54" }
datafusion-physical-expr-adapter = { version = "54" }
datafusion-physical-expr-common = { version = "54" }
Expand Down
1 change: 1 addition & 0 deletions vortex-datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ datafusion-datasource = { workspace = true, default-features = false }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-functions-nested = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-adapter = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
Expand Down
187 changes: 180 additions & 7 deletions vortex-datafusion/src/convert/exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use arrow_schema::DataType;
use arrow_schema::Field;
use arrow_schema::Schema;
use datafusion_common::Result as DFResult;
use datafusion_common::ScalarValue;
use datafusion_common::exec_datafusion_err;
use datafusion_common::tree_node::TreeNode;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_expr::Operator as DFOperator;
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_functions::string::octet_length::OctetLengthFunc;
use datafusion_functions_nested::length::ArrayLength;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr::projection::ProjectionExpr;
Expand All @@ -32,6 +34,7 @@ use vortex::expr::get_item;
use vortex::expr::is_not_null;
use vortex::expr::is_null;
use vortex::expr::list_contains;
use vortex::expr::list_length;
use vortex::expr::lit;
use vortex::expr::nested_case_when;
use vortex::expr::not;
Expand Down Expand Up @@ -155,6 +158,32 @@ impl DefaultExpressionConvertor {
Ok(cast(byte_length(input), return_dtype))
}

/// Attempts to convert DataFusion's `array_length` function (aliased as `list_length`) to
/// Vortex `list_length`.
///
/// Supports the single-argument form `array_length(arr)` and the equivalent two-argument
/// form with an explicit first dimension `array_length(arr, 1)`.
fn try_convert_array_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
let Some(input) = array_length_input(scalar_fn) else {
return Err(exec_datafusion_err!(
"array_length pushdown supports only the one-argument form or an explicit first \
dimension"
));
};

let input = self.convert(input.as_ref())?;
let return_dtype = self
.session
.arrow()
.from_arrow_field(&Field::new(
"",
scalar_fn.return_type().clone(),
scalar_fn.nullable(),
))
.map_err(|e| exec_datafusion_err!("Failed to convert return type to dtype: {e}"))?;
Ok(cast(list_length(input), return_dtype))
}

/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
if let Some(octet_length_fn) =
Expand All @@ -163,6 +192,12 @@ impl DefaultExpressionConvertor {
return self.try_convert_octet_length(octet_length_fn);
}

if let Some(array_length_fn) =
ScalarFunctionExpr::try_downcast_func::<ArrayLength>(scalar_fn)
{
return self.try_convert_array_length(array_length_fn);
}

if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
{
// DataFusion's GetFieldFunc flattens nested field access into a single call
Expand Down Expand Up @@ -511,6 +546,7 @@ fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
|| expr.downcast_ref::<ScalarFunctionExpr>().is_some_and(|sf| {
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some()
|| ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(sf).is_some()
|| ScalarFunctionExpr::try_downcast_func::<ArrayLength>(sf).is_some()
})
}

Expand Down Expand Up @@ -572,14 +608,20 @@ fn supported_data_types(dt: &DataType) -> bool {
}

/// Checks if a scalar function can be pushed down.
/// Currently GetFieldFunc and OctetLengthFunc are supported.
/// Currently GetFieldFunc, OctetLengthFunc, and ArrayLength are supported.
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
if ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some() {
return true;
}

ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
if ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
.is_some_and(|octet_length| can_octet_length_be_pushed_down(octet_length, schema))
{
return true;
}

ScalarFunctionExpr::try_downcast_func::<ArrayLength>(scalar_fn)
.is_some_and(|array_length| can_array_length_be_pushed_down(array_length, schema))
}

fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
Expand All @@ -598,6 +640,59 @@ fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Sche
}) && can_be_pushed_down_impl(input, schema)
}

fn can_array_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
let Some(input) = array_length_input(scalar_fn) else {
return false;
};

// The argument must resolve to a list type. We gate on the resolved data type rather than
// `can_be_pushed_down_impl`, since list columns are intentionally rejected there. We still
// require the argument to be a convertible expression (e.g. a column or struct field access).
input.data_type(schema).as_ref().is_ok_and(|data_type| {
matches!(
data_type,
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
)
}) && is_convertible_expr(input)
}

/// Returns the list argument of an `array_length` call if the call is a form we can rewrite to
/// `list_length`: either the single-argument form `array_length(arr)`, or the two-argument form
/// with an explicit first dimension `array_length(arr, 1)`, which is equivalent. Higher
/// dimensions recurse into nested lists and are not supported.
fn array_length_input(scalar_fn: &ScalarFunctionExpr) -> Option<&Arc<dyn PhysicalExpr>> {
match scalar_fn.args() {
[input] => Some(input),
[input, dimension] if is_dimension_one(dimension) => Some(input),
_ => None,
}
}

/// Returns true if `expr` is an integer literal equal to 1. The dimension argument of
/// `array_length` is coerced to `Int64`, but we accept any integer width defensively.
fn is_dimension_one(expr: &Arc<dyn PhysicalExpr>) -> bool {
let Some(literal) = expr.downcast_ref::<df_expr::Literal>() else {
return false;
};

let dimension = match literal.value() {
ScalarValue::Int8(Some(v)) => i64::from(*v),
ScalarValue::Int16(Some(v)) => i64::from(*v),
ScalarValue::Int32(Some(v)) => i64::from(*v),
ScalarValue::Int64(Some(v)) => *v,
ScalarValue::UInt8(Some(v)) => i64::from(*v),
ScalarValue::UInt16(Some(v)) => i64::from(*v),
ScalarValue::UInt32(Some(v)) => i64::from(*v),
ScalarValue::UInt64(Some(v)) => match i64::try_from(*v) {
Ok(v) => v,
Err(_) => return false,
},
_ => return false,
};

dimension == 1
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down Expand Up @@ -633,7 +728,7 @@ mod tests {
true,
),
Field::new(
"unsupported_list",
"tags",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
Expand All @@ -652,6 +747,21 @@ mod tests {
)
}

fn array_length_expr(
args: Vec<Arc<dyn PhysicalExpr>>,
schema: &Schema,
) -> Arc<dyn PhysicalExpr> {
Arc::new(
ScalarFunctionExpr::try_new(
Arc::new(ScalarUDF::from(ArrayLength::new())),
args,
schema,
Arc::new(ConfigOptions::new()),
)
.unwrap(),
)
}

#[test]
fn test_make_vortex_predicate_empty() {
let expr_convertor = DefaultExpressionConvertor::default();
Expand Down Expand Up @@ -798,6 +908,23 @@ mod tests {
");
}

#[rstest]
fn test_expr_from_df_array_length(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![expr], &test_schema);

let result = DefaultExpressionConvertor::default()
.convert(array_length.as_ref())
.unwrap();

assert_snapshot!(result.display_tree().to_string(), @r"
vortex.cast(u64?)
└── input: vortex.list.length()
└── input: vortex.get_item(tags)
└── input: vortex.root()
");
}

#[rstest]
// Supported types
#[case::null(DataType::Null, true)]
Expand Down Expand Up @@ -862,7 +989,7 @@ mod tests {
#[rstest]
fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
let col_expr =
Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;

assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
}
Expand Down Expand Up @@ -919,7 +1046,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let left = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let right =
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
Expand All @@ -942,7 +1069,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
"test%".to_string(),
)))) as Arc<dyn PhysicalExpr>;
Expand All @@ -962,7 +1089,7 @@ mod tests {

#[rstest]
fn test_can_be_pushed_down_octet_length_unsupported_operand(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let octet_length = Arc::new(ScalarFunctionExpr::new(
"octet_length",
Arc::new(ScalarUDF::from(OctetLengthFunc::new())),
Expand All @@ -974,6 +1101,52 @@ mod tests {
assert!(!can_be_pushed_down_impl(&octet_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_supported(test_schema: Schema) {
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![expr], &test_schema);

assert!(can_be_pushed_down_impl(&array_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_unsupported_operand(test_schema: Schema) {
// `array_length` over a non-list column cannot be pushed down.
let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
let array_length = Arc::new(ScalarFunctionExpr::new(
"array_length",
Arc::new(ScalarUDF::from(ArrayLength::new())),
vec![expr],
Arc::new(Field::new("array_length", DataType::UInt64, true)),
Arc::new(ConfigOptions::new()),
)) as Arc<dyn PhysicalExpr>;

assert!(!can_be_pushed_down_impl(&array_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_dimension_one_supported(test_schema: Schema) {
// `array_length(arr, 1)` is the first-dimension length, equivalent to `list_length`.
let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let dimension =
Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(1)))) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![list, dimension], &test_schema);

assert!(can_be_pushed_down_impl(&array_length, &test_schema));
}

#[rstest]
fn test_can_be_pushed_down_array_length_higher_dimension_not_supported(test_schema: Schema) {
// Dimensions other than 1 recurse into nested lists, which `list_length` does not model,
// so they must not be pushed down.
let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
let dimension =
Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(2)))) as Arc<dyn PhysicalExpr>;
let array_length = array_length_expr(vec![list, dimension], &test_schema);

assert!(!can_be_pushed_down_impl(&array_length, &test_schema));
}

// https://github.com/vortex-data/vortex/issues/6211
#[tokio::test]
async fn test_cast_int_to_string() -> anyhow::Result<()> {
Expand Down
Loading
Loading