diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 8d6b7e00796..6bcc7d470d8 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -232,6 +232,10 @@ harness = false name = "filter_bool" harness = false +[[bench]] +name = "list_length" +harness = false + [[bench]] name = "listview_rebuild" harness = false diff --git a/vortex-array/benches/list_length.rs b/vortex-array/benches/list_length.rs new file mode 100644 index 00000000000..60914a15712 --- /dev/null +++ b/vortex-array/benches/list_length.rs @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Benchmarks for the `list_length` scalar function over `List` and `ListView` inputs. +//! +//! `list_length` reads only the offsets/sizes (never the elements), so its cost scales with the +//! number of lists. + +#![expect(clippy::unwrap_used)] +#![expect(clippy::cast_possible_truncation)] + +use std::sync::LazyLock; + +use divan::Bencher; +use rand::RngExt; +use rand::SeedableRng; +use rand::distr::Uniform; +use rand::rngs::StdRng; +use vortex_array::ArrayRef; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ListArray; +use vortex_array::arrays::ListViewArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::expr::list_length; +use vortex_array::expr::root; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_session::VortexSession; + +fn main() { + divan::main(); +} + +static SESSION: LazyLock = LazyLock::new(vortex_array::array_session); + +const BASE_LIST_SIZE: usize = 8; + +const SMALL: usize = 100; +const MEDIUM: usize = 10_000; +const LARGE: usize = 1_000_000; + +/// A uniformly-random partition of `num_lists * LIST_SIZE` elements into `num_lists` lists, +/// plus a validity mask with ~1/8 of lists null at random positions. +fn random_lists(num_lists: usize) -> (Vec, Validity) { + let mut rng = StdRng::seed_from_u64(num_lists as u64); + let total = (num_lists * BASE_LIST_SIZE) as i32; + + let cut_dist = Uniform::new_inclusive(0i32, total).unwrap(); + let mut cuts: Vec = (0..num_lists - 1).map(|_| rng.sample(cut_dist)).collect(); + cuts.sort_unstable(); + let mut sizes = Vec::with_capacity(num_lists); + let mut prev = 0i32; + for cut in cuts { + sizes.push(cut - prev); + prev = cut; + } + sizes.push(total - prev); + + let null_dist = Uniform::new(0u32, 8).unwrap(); + let valid = (0..num_lists).map(|_| rng.sample(null_dist) != 0); + ( + sizes, + Validity::Array(BoolArray::from_iter(valid).into_array()), + ) +} + +/// A canonical `List` of `num_lists` variable-length lists, ~1/8 of them null. +fn make_list(num_lists: usize) -> ArrayRef { + let (sizes, validity) = random_lists(num_lists); + let total: i32 = sizes.iter().sum(); + let elements = PrimitiveArray::from_iter(0..total).into_array(); + let offsets: Buffer = std::iter::once(0) + .chain(sizes.iter().scan(0i32, |acc, &s| { + *acc += s; + Some(*acc) + })) + .collect(); + ListArray::try_new(elements, offsets.into_array(), validity) + .unwrap() + .into_array() +} + +/// A gapless `ListView` of `num_lists` variable-length lists, ~1/8 of them null. +fn make_listview(num_lists: usize) -> ArrayRef { + let (sizes, validity) = random_lists(num_lists); + let total: i32 = sizes.iter().sum(); + let elements = PrimitiveArray::from_iter(0..total).into_array(); + let offsets: Buffer = sizes + .iter() + .scan(0i32, |acc, &s| { + let start = *acc; + *acc += s; + Some(start) + }) + .collect(); + let sizes: Buffer = sizes.into_iter().collect(); + ListViewArray::new(elements, offsets.into_array(), sizes.into_array(), validity).into_array() +} + +/// Apply `list_length(root())` and materialize the result. +fn run(bencher: Bencher, array: ArrayRef) { + let expr = list_length(root()); + bencher + .with_inputs(|| (&array, SESSION.create_execution_ctx())) + .bench_refs(|(array, ctx)| { + array + .clone() + .apply(&expr) + .unwrap() + .execute::(ctx) + .unwrap() + }); +} + +#[divan::bench] +fn list_small(bencher: Bencher) { + run(bencher, make_list(SMALL)); +} + +#[divan::bench] +fn list_medium(bencher: Bencher) { + run(bencher, make_list(MEDIUM)); +} + +#[divan::bench] +fn list_large(bencher: Bencher) { + run(bencher, make_list(LARGE)); +} + +#[divan::bench] +fn listview_small(bencher: Bencher) { + run(bencher, make_listview(SMALL)); +} + +#[divan::bench] +fn listview_medium(bencher: Bencher) { + run(bencher, make_listview(MEDIUM)); +} + +#[divan::bench] +fn listview_large(bencher: Bencher) { + run(bencher, make_listview(LARGE)); +} diff --git a/vortex-array/src/expr/exprs.rs b/vortex-array/src/expr/exprs.rs index ab3a115c652..434f25fa616 100644 --- a/vortex-array/src/expr/exprs.rs +++ b/vortex-array/src/expr/exprs.rs @@ -37,6 +37,7 @@ use crate::scalar_fn::fns::is_null::IsNull; use crate::scalar_fn::fns::like::Like; use crate::scalar_fn::fns::like::LikeOptions; use crate::scalar_fn::fns::list_contains::ListContains; +use crate::scalar_fn::fns::list_length::ListLength; use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::mask::Mask; use crate::scalar_fn::fns::merge::DuplicateHandling; @@ -750,3 +751,17 @@ pub fn byte_length(input: Expression) -> Expression { pub fn ext_storage(input: Expression) -> Expression { ExtStorage.new_expr(EmptyOptions, [input]) } + +// ---- ListLength ---- + +/// Creates an expression that computes the number of elements in each list +/// for `List` and `FixedSizeList` inputs. This is akin to ANSI SQL `CARDINALITY()`, +/// or DuckDB's `len()`/`array_length()`. +/// +/// ```rust +/// # use vortex_array::expr::{list_length, root}; +/// let expr = list_length(root()); +/// ``` +pub fn list_length(input: Expression) -> Expression { + ListLength.new_expr(EmptyOptions, [input]) +} diff --git a/vortex-array/src/scalar_fn/fns/list_length.rs b/vortex-array/src/scalar_fn/fns/list_length.rs new file mode 100644 index 00000000000..2af0260e881 --- /dev/null +++ b/vortex-array/src/scalar_fn/fns/list_length.rs @@ -0,0 +1,401 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::AsPrimitive; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_session::VortexSession; +use vortex_session::registry::CachedId; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::array::ArrayView; +use crate::arrays::ConstantArray; +use crate::arrays::List; +use crate::arrays::ListView; +use crate::arrays::list::ListArrayExt; +use crate::arrays::listview::ListViewArrayExt; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::expr::Expression; +use crate::matcher::Matcher; +use crate::scalar::Scalar; +use crate::scalar_fn::Arity; +use crate::scalar_fn::ChildName; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ExecutionArgs; +use crate::scalar_fn::ReduceCtx; +use crate::scalar_fn::ReduceNode; +use crate::scalar_fn::ReduceNodeRef; +use crate::scalar_fn::ScalarFnId; +use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::literal::Literal; +use crate::scalar_fn::fns::operators::Operator; + +/// Number of elements in each list of a `List` or `FixedSizeList` typed array. +/// +/// This is computed purely from the list's offsets (`ListArray`), sizes (`ListViewArray`), or +/// dtype (`FixedSizeListArray`) without reading the element *values*. Validity is carried over +/// from the original array. +#[derive(Clone)] +pub struct ListLength; + +impl ScalarFnVTable for ListLength { + type Options = EmptyOptions; + + fn id(&self) -> ScalarFnId { + static ID: CachedId = CachedId::new("vortex.list.length"); + *ID + } + + fn serialize(&self, _instance: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + _ => unreachable!("Invalid child index {child_idx} for list_length()"), + } + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + match &arg_dtypes[0] { + DType::List(_, nullable) | DType::FixedSizeList(_, _, nullable) => { + Ok(DType::Primitive(PType::U64, *nullable)) + } + other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"), + } + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let input = args.get(0)?; + let nullability = input.dtype().nullability(); + + if let Some(scalar) = input.as_constant() { + let len_scalar = scalar_list_length(&scalar, nullability)?; + return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array()); + } + + list_length(&input, nullability, ctx) + } + + fn reduce( + &self, + _options: &Self::Options, + node: &dyn ReduceNode, + ctx: &dyn ReduceCtx, + ) -> VortexResult> { + // The length of nonnullable fixed-size list is constant + if let DType::FixedSizeList(_, size, Nullability::NonNullable) = + node.child(0).node_dtype()? + { + let length = Scalar::primitive(size as u64, Nullability::NonNullable); + return Ok(Some(ctx.new_node(Literal.bind(length), &[])?)); + } + Ok(None) + } + + fn validity( + &self, + _: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +fn scalar_list_length(scalar: &Scalar, nullability: Nullability) -> VortexResult { + if scalar.is_null() { + let dtype = DType::Primitive(PType::U64, Nullability::Nullable); + return Ok(Scalar::null(dtype)); + } + let len: u64 = scalar.as_list().len().as_(); + Ok(Scalar::primitive(len, nullability)) +} + +pub(crate) fn list_length( + array: &ArrayRef, + nullability: Nullability, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let (lengths, validity) = match array.dtype() { + // The length of fixed-size list is constant, so just need to carry over validity + DType::FixedSizeList(_, size, _) => { + let lengths = ConstantArray::new( + Scalar::primitive(*size as u64, Nullability::NonNullable), + array.len(), + ) + .into_array(); + (lengths, array.validity()?) + } + DType::List(..) => { + let list = array.clone().execute_until::(ctx)?; + + if let Some(list) = list.as_opt::() { + let lengths = list_length_from_offsets(list)?; + (lengths, list.list_validity()) + } else if let Some(list_view) = list.as_opt::() { + // Length array is exactly the sizes child + (list_view.sizes().clone(), list_view.listview_validity()) + } else { + unreachable!("AnyList matcher guarantees List or ListView") + } + } + other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"), + }; + + // Cast to `U64` + let len = lengths.len(); + let lengths = lengths.cast(DType::Primitive(PType::U64, nullability))?; + + // Carry over validity mask for nullable arrays + if matches!(nullability, Nullability::Nullable) { + lengths.mask(validity.to_array(len)) + } else { + Ok(lengths) + } +} + +/// Calculate the lengths of `ListArray` elements via the `offsets` child: +/// `length[i] = offsets[i + 1] - offsets[i]`. +fn list_length_from_offsets(list: ArrayView<'_, List>) -> VortexResult { + let offsets = list.offsets(); + let n = offsets.len().saturating_sub(1); + + offsets + .slice(1..offsets.len())? + .binary(offsets.slice(0..n)?, Operator::Sub) +} + +/// Matches an `Array` or `Array`. +struct AnyList; + +impl Matcher for AnyList { + type Match<'a> = (); + + fn try_match(array: &ArrayRef) -> Option> { + (array.as_opt::().is_some() || array.as_opt::().is_some()).then_some(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rstest::rstest; + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::VortexSessionExecute; + use crate::arrays::BoolArray; + use crate::arrays::ConstantArray; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListArray; + use crate::arrays::ListViewArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::ScalarFn; + use crate::arrays::scalar_fn::ScalarFnArrayExt; + use crate::assert_arrays_eq; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::expr::list_length; + use crate::expr::root; + use crate::scalar::Scalar; + use crate::scalar_fn::fns::literal::Literal; + use crate::validity::Validity; + + fn create_list_elements() -> ArrayRef { + PrimitiveArray::from_option_iter::([ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + None, + ]) + .into_array() + } + + #[rstest] + #[case(buffer![0u32, 2, 5, 5, 7].into_array())] + #[case(buffer![0u64, 2, 5, 5, 7].into_array())] + fn test_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> { + let elements = create_list_elements(); + let list = ListArray::try_new(elements, offsets, Validity::NonNullable)?.into_array(); + let result = list.apply(&list_length(root()))?; + assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 3, 0, 2])); + Ok(()) + } + + #[rstest] + #[case(buffer![0u32, 2, 5, 5, 7].into_array())] + #[case(buffer![0u64, 2, 5, 5, 7].into_array())] + fn test_nullable_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> { + let elements = create_list_elements(); + let list = ListArray::try_new( + elements, + offsets, + Validity::Array(BoolArray::from_iter([true, false, true, false]).into_array()), + )? + .into_array(); + let result = list.apply(&list_length(root()))?; + + let session = VortexSession::empty(); + let mut ctx = session.create_execution_ctx(); + let result = result.execute::(&mut ctx)?; + + let expected = PrimitiveArray::from_option_iter::([Some(2), None, Some(0), None]); + + assert_arrays_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_null_scalar_list_length() -> VortexResult<()> { + let null_scalar = Scalar::null(DType::List( + Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)), + Nullability::Nullable, + )); + let array = ConstantArray::new(null_scalar, 2).into_array(); + let result = array.apply(&list_length(root()))?; + + let session = VortexSession::empty(); + let mut ctx = session.create_execution_ctx(); + assert!(!result.is_valid(0, &mut ctx)?); + assert!(!result.is_valid(1, &mut ctx)?); + Ok(()) + } + + #[test] + fn test_listview_length() -> VortexResult<()> { + let elements = create_list_elements(); + let lv = ListViewArray::new( + elements, + buffer![5u32, 0, 4, 1].into_array(), + buffer![2u32, 3, 0, 2].into_array(), + Validity::NonNullable, + ) + .into_array(); + let result = lv.apply(&list_length(root()))?; + assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 3, 0, 2])); + Ok(()) + } + + #[test] + fn test_listview_length_nullable() -> VortexResult<()> { + let elements = create_list_elements(); + let lv = ListViewArray::new( + elements, + buffer![5u32, 0, 4, 1].into_array(), + buffer![2u32, 3, 0, 2].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true, false]).into_array()), + ) + .into_array(); + let result = lv.apply(&list_length(root()))?; + + let session = VortexSession::empty(); + let mut ctx = session.create_execution_ctx(); + let result = result.execute::(&mut ctx)?; + + let expected = PrimitiveArray::from_option_iter::([Some(2), None, Some(0), None]); + assert_arrays_eq!(result, expected); + Ok(()) + } + + #[test] + fn test_list_length_take() -> VortexResult<()> { + let elements = create_list_elements(); + let list = ListArray::try_new( + elements, + buffer![0u32, 2, 5, 5, 7].into_array(), + Validity::NonNullable, + )? + .into_array(); + let taken = list.take(buffer![3u64, 0, 2].into_array())?; + + let result = taken.apply(&list_length(root()))?; + assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 2, 0])); + Ok(()) + } + + fn create_fixed_size_list(validity: Validity) -> ArrayRef { + // 4 lists of size 2 over 8 primitive elements. + let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6, 7, 8]).into_array(); + FixedSizeListArray::new(elements, 2, validity, 4).into_array() + } + + #[test] + fn test_fixed_size_list_length() -> VortexResult<()> { + let fsl = create_fixed_size_list(Validity::NonNullable); + let result = fsl.apply(&list_length(root()))?; + + // A non-nullable fixed-size list reduces to a constant literal length, never touching the + // `ListLength` execution path. + assert!( + result + .as_opt::() + .is_some_and(|f| f.scalar_fn().as_opt::().is_some()), + "list_length over a non-nullable FixedSizeList must reduce to a constant literal" + ); + assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 2, 2, 2])); + Ok(()) + } + + #[test] + fn test_fixed_size_list_length_nullable() -> VortexResult<()> { + let fsl = create_fixed_size_list(Validity::Array( + BoolArray::from_iter([true, false, true, false]).into_array(), + )); + let result = fsl.apply(&list_length(root()))?; + + let session = VortexSession::empty(); + let mut ctx = session.create_execution_ctx(); + let result = result.execute::(&mut ctx)?; + + let expected = PrimitiveArray::from_option_iter::([Some(2), None, Some(2), None]); + assert_arrays_eq!(result, expected); + Ok(()) + } + + #[test] + fn test_display() { + let expr = list_length(root()); + assert_eq!(expr.to_string(), "vortex.list.length($)"); + } +} diff --git a/vortex-array/src/scalar_fn/fns/mod.rs b/vortex-array/src/scalar_fn/fns/mod.rs index 41843d0a18e..f7e98e676b0 100644 --- a/vortex-array/src/scalar_fn/fns/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mod.rs @@ -14,6 +14,7 @@ pub mod is_not_null; pub mod is_null; pub mod like; pub mod list_contains; +pub mod list_length; pub mod literal; pub mod mask; pub mod merge; diff --git a/vortex-array/src/scalar_fn/session.rs b/vortex-array/src/scalar_fn/session.rs index b85120e5345..2109b3cd3ce 100644 --- a/vortex-array/src/scalar_fn/session.rs +++ b/vortex-array/src/scalar_fn/session.rs @@ -20,6 +20,7 @@ use crate::scalar_fn::fns::is_not_null::IsNotNull; use crate::scalar_fn::fns::is_null::IsNull; use crate::scalar_fn::fns::like::Like; use crate::scalar_fn::fns::list_contains::ListContains; +use crate::scalar_fn::fns::list_length::ListLength; use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::merge::Merge; use crate::scalar_fn::fns::not::Not; @@ -67,6 +68,7 @@ impl Default for ScalarFnSession { this.register(IsNull); this.register(Like); this.register(ListContains); + this.register(ListLength); this.register(Literal); this.register(Merge); this.register(Not); diff --git a/vortex-duckdb/cpp/expr.cpp b/vortex-duckdb/cpp/expr.cpp index b15ce8e02cd..616b7aa4f88 100644 --- a/vortex-duckdb/cpp/expr.cpp +++ b/vortex-duckdb/cpp/expr.cpp @@ -38,6 +38,14 @@ extern "C" duckdb_vx_expr_class duckdb_vx_expr_get_class(duckdb_vx_expr ffi_expr return static_cast(expr->GetExpressionClass()); } +extern "C" duckdb_logical_type duckdb_vx_expr_get_return_type(duckdb_vx_expr ffi_expr) { + if (!ffi_expr) { + return nullptr; + } + auto expr = reinterpret_cast(ffi_expr); + return reinterpret_cast(&expr->return_type); +} + extern "C" const char *duckdb_vx_expr_get_bound_column_ref_get_name(duckdb_vx_expr ffi_expr) { if (!ffi_expr) { return nullptr; diff --git a/vortex-duckdb/cpp/include/duckdb_vx/expr.h b/vortex-duckdb/cpp/include/duckdb_vx/expr.h index b1344341403..5e6277a4fa8 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/expr.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/expr.h @@ -209,6 +209,10 @@ typedef enum DUCKDB_VX_EXPR_TYPE { duckdb_vx_expr_class duckdb_vx_expr_get_class(duckdb_vx_expr expr); +/// Return the (bound) return type of the expression. The logical type is borrowed from the +/// expression and must not be freed. +duckdb_logical_type duckdb_vx_expr_get_return_type(duckdb_vx_expr expr); + const char *duckdb_vx_expr_get_bound_column_ref_get_name(duckdb_vx_expr expr); duckdb_value duckdb_vx_expr_bound_constant_get_value(duckdb_vx_expr expr); diff --git a/vortex-duckdb/src/convert/expr.rs b/vortex-duckdb/src/convert/expr.rs index 324086e5775..6712d80e972 100644 --- a/vortex-duckdb/src/convert/expr.rs +++ b/vortex-duckdb/src/convert/expr.rs @@ -22,6 +22,7 @@ use vortex::expr::get_item; use vortex::expr::is_not_null; use vortex::expr::is_null; use vortex::expr::list_contains; +use vortex::expr::list_length; use vortex::expr::lit; use vortex::expr::not; use vortex::expr::or_collect; @@ -37,6 +38,7 @@ use vortex::scalar_fn::fns::like::LikeOptions; use vortex::scalar_fn::fns::literal::Literal; use vortex::scalar_fn::fns::operators::Operator; +use crate::cpp::DUCKDB_TYPE; use crate::cpp::DUCKDB_VX_EXPR_TYPE; use crate::duckdb; use crate::duckdb::BoundFunction; @@ -57,6 +59,15 @@ fn from_bound_str(value: &duckdb::ExpressionRef) -> VortexResult { } } +/// Returns true if the expression's bound return type is a `LIST` or fixed-size `ARRAY`. Used to +/// disambiguate the overloaded `len`/`length` functions, which also apply to strings and bits. +fn is_list_typed(expr: &duckdb::ExpressionRef) -> bool { + matches!( + expr.return_type().as_type_id(), + DUCKDB_TYPE::DUCKDB_TYPE_LIST | DUCKDB_TYPE::DUCKDB_TYPE_ARRAY + ) +} + fn try_from_bound_function( func: &BoundFunction, col_sub: Option<&Expression>, @@ -115,6 +126,25 @@ fn try_from_bound_function( }; Like.new_expr(LikeOptions::default(), [value, lit(pattern)]) } + // `array_length` is list-only, but `len`/`length` are also defined for strings and bits, + // so we gate on the argument's bound return type being a list or fixed-size array. + "len" | "length" | "array_length" => { + let children: Vec<_> = func.children().collect(); + // Only the single-argument form maps to list_length; the two-argument + // (dimension) form of array_length has different semantics. + if children.len() != 1 || !is_list_typed(children[0]) { + return Ok(None); + } + let Some(col) = try_from_expression_inner(children[0], col_sub)? else { + return Ok(None); + }; + let col = list_length(col); + // list_length returns u64, len()/array_length() return i64. We don't know the + // column's nullability here, so we set it to nullable; for a non-nullable + // column the validity will be AllValid so it's a marginal cost. + let dtype = DType::Primitive(PType::I64, Nullability::Nullable); + cast(col, dtype) + } _ => { debug!("bound function {}", func.scalar_function.name()); return Ok(None); @@ -173,6 +203,10 @@ pub fn can_push_expression(value: &duckdb::ExpressionRef) -> bool { || name == "~~" || name == "!~~" || name == "strlen" + || (matches!(name, "len" | "length" | "array_length") && { + let children: Vec<_> = func.children().collect(); + children.len() == 1 && is_list_typed(children[0]) + }) } ExpressionClass::BoundOperator(op) => { if !matches!( @@ -208,6 +242,18 @@ pub fn try_from_projection_expression( let col = cast(col, dtype); Some(col) } + // `len`/`length`/`array_length` over a list column. The column dtype is known here, so + // unlike the filter path we can safely accept the overloaded `len`/`length` names by + // gating on the field being a list. Only the single-argument form is supported. + "len" | "length" | "array_length" + if matches!(field.dtype, DType::List(..) | DType::FixedSizeList(..)) + && func.children().count() == 1 => + { + let col = list_length(get_item(field.name.as_str(), root())); + // list_length returns u64, len()/array_length() return i64 (BIGINT). + let dtype = DType::Primitive(PType::I64, field.dtype.nullability()); + Some(cast(col, dtype)) + } _ => None, }) } diff --git a/vortex-duckdb/src/duckdb/expr.rs b/vortex-duckdb/src/duckdb/expr.rs index 2b206bc192f..0c1128c9131 100644 --- a/vortex-duckdb/src/duckdb/expr.rs +++ b/vortex-duckdb/src/duckdb/expr.rs @@ -10,6 +10,8 @@ use std::ptr; use crate::cpp; use crate::cpp::duckdb_vx_expr_class; use crate::duckdb::DDBString; +use crate::duckdb::LogicalType; +use crate::duckdb::LogicalTypeRef; use crate::duckdb::ScalarFunction; use crate::duckdb::ScalarFunctionRef; use crate::duckdb::Value; @@ -33,6 +35,11 @@ impl ExpressionRef { unsafe { cpp::duckdb_vx_expr_get_class(self.as_ptr()) } } + /// The (bound) return type of this expression. Borrowed from the expression. + pub fn return_type(&self) -> &LogicalTypeRef { + unsafe { LogicalType::borrow(cpp::duckdb_vx_expr_get_return_type(self.as_ptr())) } + } + /// Match the subclass of the expression. pub fn as_class(&self) -> Option> { Some( diff --git a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs index 99bab3f71ce..706bbd37bf3 100644 --- a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs +++ b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs @@ -992,3 +992,101 @@ fn test_geometry() { let area = vec.as_slice_with_len::(chunk.len().as_())[0]; assert_eq!(area, 1000.0); } + +/// `SELECT array_length(list)` / `len(list)` / `length(list)` should push the list-length +/// computation into the Vortex scan (computed from offsets, without materializing the list +/// elements) and return the per-row element counts. +#[test] +fn test_vortex_scan_list_length_projection() { + let file = RUNTIME.block_on(async { + let integers = PrimitiveArray::from_iter([ + 10i32, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, + ]); + // Variable-length lists with 3, 4, 1, 5, 2 elements respectively. + let offsets = buffer![0i32, 3, 7, 8, 13, 15]; + let list_array = ListArray::try_new( + integers.into_array(), + offsets.into_array(), + Validity::AllValid, + ) + .unwrap(); + + write_single_column_vortex_file("int_list", list_array).await + }); + + let conn = database_connection(); + let file_path = file.path().to_string_lossy(); + + // `len`/`length` bind to the same DuckDB function set as `array_length` for list arguments. + for func in ["array_length", "len", "length"] { + let result = conn + .query(&format!("SELECT {func}(int_list) FROM '{file_path}'")) + .unwrap(); + + let mut lengths = Vec::new(); + for chunk in result { + let len = chunk.len().as_(); + let vec = chunk.get_vector(0); + lengths.extend_from_slice(vec.as_slice_with_len::(len)); + } + + assert_eq!(lengths, vec![3, 4, 1, 5, 2], "{func}(int_list) mismatch"); + } +} + +/// `WHERE array_length(list) >= k` should push down as a complex filter. +#[test] +fn test_vortex_scan_list_length_filter() { + let file = RUNTIME.block_on(async { + let integers = PrimitiveArray::from_iter([ + 10i32, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, + ]); + // Variable-length lists with 3, 4, 1, 5, 2 elements respectively. + let offsets = buffer![0i32, 3, 7, 8, 13, 15]; + let list_array = ListArray::try_new( + integers.into_array(), + offsets.into_array(), + Validity::AllValid, + ) + .unwrap(); + + write_single_column_vortex_file("int_list", list_array).await + }); + + // Lists with length >= 4: the 4-element and 5-element lists => 2 rows. + let count = scan_vortex_file_single_row::( + file, + "SELECT COUNT(*) FROM ? WHERE array_length(int_list) >= 4", + 0, + ); + assert_eq!(count, 2); +} + +/// `array_length`/`len`/`length` over a FixedSizeList column. The length is the fixed list size. +#[test] +fn test_vortex_scan_fixed_size_list_length_projection() { + let file = RUNTIME.block_on(async { + // 6 fixed-size lists of 4 i32 elements each. + let elements = (0..24i32).collect::(); + let fsl = FixedSizeListArray::new(elements.into_array(), 4, Validity::AllValid, 6); + write_single_column_vortex_file("int_lists", fsl).await + }); + + let conn = database_connection(); + let file_path = file.path().to_string_lossy(); + + for func in ["array_length", "len", "length"] { + let result = conn + .query(&format!("SELECT {func}(int_lists) FROM '{file_path}'")) + .unwrap(); + + let mut lengths = Vec::new(); + for chunk in result { + let len = chunk.len().as_(); + let vec = chunk.get_vector(0); + lengths.extend_from_slice(vec.as_slice_with_len::(len)); + } + + assert_eq!(lengths, vec![4i64; 6], "{func}(int_lists) mismatch"); + } +}