diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml
index 5ef0c5c52ff..98de7caf95d 100644
--- a/vortex-array/Cargo.toml
+++ b/vortex-array/Cargo.toml
@@ -236,6 +236,10 @@ harness = false
 name = "filter_bool"
 harness = false
 
+[[bench]]
+name = "list_length"
+harness = false
+
 [[bench]]
 name = "listview_rebuild"
 harness = false
diff --git a/vortex-array/benches/list_length.rs b/vortex-array/benches/list_length.rs
new file mode 100644
index 00000000000..6148b54f8f9
--- /dev/null
+++ b/vortex-array/benches/list_length.rs
@@ -0,0 +1,149 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Benchmarks for the `list_length` scalar function over `List` and `ListView` inputs.
+//!
+//! `list_length` reads only the offsets/sizes (never the elements), so its cost scales with the
+//! number of lists.
+
+#![expect(clippy::unwrap_used)]
+#![expect(clippy::cast_possible_truncation)]
+
+use std::sync::LazyLock;
+
+use divan::Bencher;
+use rand::RngExt;
+use rand::SeedableRng;
+use rand::distr::Uniform;
+use rand::rngs::StdRng;
+use vortex_array::ArrayRef;
+use vortex_array::Canonical;
+use vortex_array::IntoArray;
+use vortex_array::VortexSessionExecute;
+use vortex_array::arrays::BoolArray;
+use vortex_array::arrays::ListArray;
+use vortex_array::arrays::ListViewArray;
+use vortex_array::arrays::PrimitiveArray;
+use vortex_array::expr::list_length;
+use vortex_array::expr::root;
+use vortex_array::validity::Validity;
+use vortex_buffer::Buffer;
+use vortex_session::VortexSession;
+
+fn main() {
+    divan::main();
+}
+
+/// Shared session from which every benchmark input creates its own execution context.
+static SESSION: LazyLock<VortexSession> = LazyLock::new(vortex_array::array_session);
+
+/// Mean list length: `random_lists` spreads `num_lists * BASE_LIST_SIZE` elements over the lists.
+const BASE_LIST_SIZE: usize = 8;
+
+// Number of lists in the small / medium / large benchmark inputs.
+const SMALL: usize = 100;
+const MEDIUM: usize = 10_000;
+const LARGE: usize = 1_000_000;
+
+/// A uniformly-random partition of `num_lists * BASE_LIST_SIZE` elements into `num_lists` lists,
+/// plus a validity mask with ~1/8 of lists null at random positions.
+fn random_lists(num_lists: usize) -> (Vec<i32>, Validity) {
+    let mut rng = StdRng::seed_from_u64(num_lists as u64);
+    let total = (num_lists * BASE_LIST_SIZE) as i32;
+
+    let cut_dist = Uniform::new_inclusive(0i32, total).unwrap();
+    let mut cuts: Vec<i32> = (0..num_lists - 1).map(|_| rng.sample(cut_dist)).collect();
+    cuts.sort_unstable();
+    let mut sizes = Vec::with_capacity(num_lists);
+    let mut prev = 0i32;
+    for cut in cuts {
+        sizes.push(cut - prev);
+        prev = cut;
+    }
+    sizes.push(total - prev);
+
+    let null_dist = Uniform::new(0u32, 8).unwrap();
+    let valid = (0..num_lists).map(|_| rng.sample(null_dist) != 0);
+    (
+        sizes,
+        Validity::Array(BoolArray::from_iter(valid).into_array()),
+    )
+}
+
+/// A canonical `List` of `num_lists` variable-length lists, ~1/8 of them null.
+fn make_list(num_lists: usize) -> ArrayRef {
+    let (sizes, validity) = random_lists(num_lists);
+    let total: i32 = sizes.iter().sum();
+    let elements = PrimitiveArray::from_iter(0..total).into_array();
+    let offsets: Buffer<i32> = std::iter::once(0)
+        .chain(sizes.iter().scan(0i32, |acc, &s| {
+            *acc += s;
+            Some(*acc)
+        }))
+        .collect();
+    ListArray::try_new(elements, offsets.into_array(), validity)
+        .unwrap()
+        .into_array()
+}
+
+/// A gapless `ListView` of `num_lists` variable-length lists, ~1/8 of them null.
+fn make_listview(num_lists: usize) -> ArrayRef {
+    let (sizes, validity) = random_lists(num_lists);
+    let total: i32 = sizes.iter().sum();
+    let elements = PrimitiveArray::from_iter(0..total).into_array();
+    let offsets: Buffer<i32> = sizes
+        .iter()
+        .scan(0i32, |acc, &s| {
+            let start = *acc;
+            *acc += s;
+            Some(start)
+        })
+        .collect();
+    let sizes: Buffer<i32> = sizes.into_iter().collect();
+    ListViewArray::new(elements, offsets.into_array(), sizes.into_array(), validity).into_array()
+}
+
+/// Apply `list_length(root())` and materialize the result.
+fn run(bencher: Bencher, array: ArrayRef) {
+    let expr = list_length(root());
+    bencher
+        .with_inputs(|| (&array, SESSION.create_execution_ctx()))
+        .bench_refs(|(array, ctx)| {
+            array
+                .clone()
+                .apply(&expr)
+                .unwrap()
+                .execute::<Canonical>(ctx)
+                .unwrap()
+        });
+}
+
+#[divan::bench]
+fn list_length_small(bencher: Bencher) {
+    run(bencher, make_list(SMALL));
+}
+
+#[divan::bench]
+fn list_length_medium(bencher: Bencher) {
+    run(bencher, make_list(MEDIUM));
+}
+
+#[divan::bench]
+fn list_length_large(bencher: Bencher) {
+    run(bencher, make_list(LARGE));
+}
+
+#[divan::bench]
+fn listview_length_small(bencher: Bencher) {
+    run(bencher, make_listview(SMALL));
+}
+
+#[divan::bench]
+fn listview_length_medium(bencher: Bencher) {
+    run(bencher, make_listview(MEDIUM));
+}
+
+#[divan::bench]
+fn listview_length_large(bencher: Bencher) {
+    run(bencher, make_listview(LARGE));
+}
diff --git a/vortex-array/src/expr/exprs.rs b/vortex-array/src/expr/exprs.rs
index 303424f56c0..ec12a321552 100644
--- a/vortex-array/src/expr/exprs.rs
+++ b/vortex-array/src/expr/exprs.rs
@@ -37,6 +37,7 @@ use crate::scalar_fn::fns::is_null::IsNull;
 use crate::scalar_fn::fns::like::Like;
 use crate::scalar_fn::fns::like::LikeOptions;
 use crate::scalar_fn::fns::list_contains::ListContains;
+use crate::scalar_fn::fns::list_length::ListLength;
 use crate::scalar_fn::fns::literal::Literal;
 use crate::scalar_fn::fns::mask::Mask;
 use crate::scalar_fn::fns::merge::DuplicateHandling;
@@ -750,3 +751,17 @@ pub fn byte_length(input: Expression) -> Expression {
 pub fn ext_storage(input: Expression) -> Expression {
     ExtStorage.new_expr(EmptyOptions, [input])
 }
+
+// ---- ListLength ----
+
+/// Creates an expression that computes the number of elements in each list
+/// for `List` and `FixedSizeList` inputs. This is akin to ANSI SQL `CARDINALITY()`,
+/// or DuckDB's `len()`/`array_length()`.
+///
+/// ```rust
+/// # use vortex_array::expr::{list_length, root};
+/// let expr = list_length(root());
+/// ```
+pub fn list_length(input: Expression) -> Expression {
+    ListLength.new_expr(EmptyOptions, [input])
+}
diff --git a/vortex-array/src/scalar_fn/fns/list_length.rs b/vortex-array/src/scalar_fn/fns/list_length.rs
new file mode 100644
index 00000000000..ea88a3b27d6
--- /dev/null
+++ b/vortex-array/src/scalar_fn/fns/list_length.rs
@@ -0,0 +1,407 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use num_traits::AsPrimitive;
+use vortex_error::VortexResult;
+use vortex_error::vortex_bail;
+use vortex_session::VortexSession;
+use vortex_session::registry::CachedId;
+
+use crate::ArrayRef;
+use crate::ExecutionCtx;
+use crate::IntoArray;
+use crate::array::ArrayView;
+use crate::arrays::ConstantArray;
+use crate::arrays::FixedSizeList;
+use crate::arrays::List;
+use crate::arrays::ListView;
+use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
+use crate::arrays::list::ListArrayExt;
+use crate::arrays::listview::ListViewArrayExt;
+use crate::builtins::ArrayBuiltins;
+use crate::dtype::DType;
+use crate::dtype::Nullability;
+use crate::dtype::PType;
+use crate::expr::Expression;
+use crate::matcher::Matcher;
+use crate::scalar::Scalar;
+use crate::scalar_fn::Arity;
+use crate::scalar_fn::ChildName;
+use crate::scalar_fn::EmptyOptions;
+use crate::scalar_fn::ExecutionArgs;
+use crate::scalar_fn::ScalarFnId;
+use crate::scalar_fn::ScalarFnVTable;
+use crate::scalar_fn::fns::operators::Operator;
+
+/// Number of elements in each list of a `List` or `FixedSizeList` typed array.
+///
+/// This is computed purely from the list's offsets (`ListArray`), sizes (`ListViewArray`), or
+/// dtype (`FixedSizeListArray`) without reading the element *values*. Validity is carried over
+/// from the original array.
+#[derive(Clone)]
+pub struct ListLength;
+
+impl ScalarFnVTable for ListLength {
+    type Options = EmptyOptions;
+
+    fn id(&self) -> ScalarFnId {
+        static ID: CachedId = CachedId::new("vortex.list.length");
+        *ID
+    }
+
+    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
+        Ok(Some(vec![]))
+    }
+
+    fn deserialize(
+        &self,
+        _metadata: &[u8],
+        _session: &VortexSession,
+    ) -> VortexResult<Self::Options> {
+        Ok(EmptyOptions)
+    }
+
+    fn arity(&self, _options: &Self::Options) -> Arity {
+        Arity::Exact(1)
+    }
+
+    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
+        match child_idx {
+            0 => ChildName::from("input"),
+            _ => unreachable!("Invalid child index {child_idx} for list_length()"),
+        }
+    }
+
+    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
+        match &arg_dtypes[0] {
+            DType::List(_, nullable) | DType::FixedSizeList(_, _, nullable) => {
+                Ok(DType::Primitive(PType::U64, *nullable))
+            }
+            other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"),
+        }
+    }
+
+    fn execute(
+        &self,
+        _options: &Self::Options,
+        args: &dyn ExecutionArgs,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<ArrayRef> {
+        let input = args.get(0)?;
+        let nullability = input.dtype().nullability();
+
+        if let Some(scalar) = input.as_constant() {
+            let len_scalar = scalar_list_length(&scalar, nullability)?;
+            return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array());
+        }
+
+        list_length(&input, nullability, ctx)
+    }
+
+    fn validity(
+        &self,
+        _: &Self::Options,
+        expression: &Expression,
+    ) -> VortexResult<Option<Expression>> {
+        Ok(Some(expression.child(0).validity()?))
+    }
+
+    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
+        false
+    }
+
+    fn is_fallible(&self, _options: &Self::Options) -> bool {
+        false
+    }
+}
+
+/// Length of a constant list `scalar` as a `u64` scalar.
+///
+/// A null `scalar` yields a null scalar of nullable `u64` dtype; otherwise the result is
+/// `scalar.as_list().len()` converted to `u64` with the given `nullability`.
+fn scalar_list_length(scalar: &Scalar, nullability: Nullability) -> VortexResult<Scalar> {
+    if scalar.is_null() {
+        let dtype = DType::Primitive(PType::U64, Nullability::Nullable);
+        return Ok(Scalar::null(dtype));
+    }
+    let len: u64 = scalar.as_list().len().as_();
+    Ok(Scalar::primitive(len, nullability))
+}
+
+/// Computes per-row list lengths of `array` as a `u64` array with the given `nullability`.
+///
+/// `array` is executed until it matches `AnyList`; lengths then come from the constant list size
+/// (`FixedSizeList`), the `sizes` child (`ListView`), or offset differences (`List`). For a
+/// nullable input the result is masked with the input's validity.
+pub(crate) fn list_length(
+    array: &ArrayRef,
+    nullability: Nullability,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<ArrayRef> {
+    let any_list = array.clone().execute_until::<AnyList>(ctx)?;
+
+    let (lengths, validity) = if let Some(fsl) = any_list.as_opt::<FixedSizeList>() {
+        // The length of fixed-size list is constant, so just need to carry over validity
+        let size = fsl.list_size() as u64;
+        let lengths =
+            ConstantArray::new(Scalar::primitive(size, Nullability::NonNullable), fsl.len())
+                .into_array();
+        (lengths, fsl.validity()?)
+    } else if let Some(lv) = any_list.as_opt::<ListView>() {
+        // Length array is exactly the sizes child
+        (lv.sizes().clone(), lv.listview_validity())
+    } else if let Some(l) = any_list.as_opt::<List>() {
+        let lengths = list_length_from_offsets(l)?;
+        (lengths, l.list_validity())
+    } else {
+        let dtype = any_list.dtype();
+        vortex_bail!("list_length() requires List, ListView, or FixedSizeList but got {dtype}")
+    };
+
+    // Cast to `U64`
+    let len = lengths.len();
+    let lengths = lengths.cast(DType::Primitive(PType::U64, nullability))?;
+
+    // Carry over validity mask for nullable arrays
+    if matches!(nullability, Nullability::Nullable) {
+        lengths.mask(validity.to_array(len))
+    } else {
+        Ok(lengths)
+    }
+}
+
+/// Calculate the lengths of `ListArray` elements via the `offsets` child:
+/// `length[i] = offsets[i + 1] - offsets[i]`.
+fn list_length_from_offsets(list: ArrayView<'_, List>) -> VortexResult<ArrayRef> {
+    let offsets = list.offsets();
+    let n = offsets.len().saturating_sub(1);
+
+    offsets
+        .slice(1..offsets.len())?
+        .binary(offsets.slice(0..n)?, Operator::Sub)
+}
+
+/// Matches an `Array<List>`, `Array<ListView>`, or `Array<FixedSizeList>`
+struct AnyList;
+
+impl Matcher for AnyList {
+    type Match<'a> = ();
+
+    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
+        (array.as_opt::<List>().is_some()
+            || array.as_opt::<ListView>().is_some()
+            || array.as_opt::<FixedSizeList>().is_some())
+        .then_some(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    // NOTE(review): the turbofish type arguments in this module (`from_option_iter::<..>`,
+    // `execute::<..>`) were reconstructed from context — confirm against the upstream file.
+    use std::sync::Arc;
+
+    use rstest::rstest;
+    use vortex_buffer::buffer;
+    use vortex_error::VortexResult;
+
+    use crate::ArrayRef;
+    use crate::IntoArray;
+    use crate::VortexSessionExecute;
+    use crate::array_session;
+    use crate::arrays::BoolArray;
+    use crate::arrays::ConstantArray;
+    use crate::arrays::FixedSizeListArray;
+    use crate::arrays::ListArray;
+    use crate::arrays::ListViewArray;
+    use crate::arrays::PrimitiveArray;
+    use crate::assert_arrays_eq;
+    use crate::dtype::DType;
+    use crate::dtype::Nullability;
+    use crate::dtype::PType;
+    use crate::expr::cast;
+    use crate::expr::list_length;
+    use crate::expr::root;
+    use crate::scalar::Scalar;
+    use crate::validity::Validity;
+
+    fn create_list_elements() -> ArrayRef {
+        PrimitiveArray::from_option_iter::<i32, _>([
+            Some(1),
+            Some(2),
+            Some(3),
+            Some(4),
+            Some(5),
+            Some(6),
+            None,
+        ])
+        .into_array()
+    }
+
+    #[rstest]
+    #[case(buffer![0u32, 2, 5, 5, 7].into_array())]
+    #[case(buffer![0u64, 2, 5, 5, 7].into_array())]
+    fn test_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> {
+        let elements = create_list_elements();
+        let list = ListArray::try_new(elements, offsets, Validity::NonNullable)?.into_array();
+        let result = list.apply(&list_length(root()))?;
+        let mut ctx = array_session().create_execution_ctx();
+        assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 3, 0, 2]), &mut ctx);
+        Ok(())
+    }
+
+    #[rstest]
+    #[case(buffer![0u32, 2, 5, 5, 7].into_array())]
+    #[case(buffer![0u64, 2, 5, 5, 7].into_array())]
+    fn test_nullable_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> {
+        let elements = create_list_elements();
+        let list = ListArray::try_new(
+            elements,
+            offsets,
+            Validity::Array(BoolArray::from_iter([true, false, true, false]).into_array()),
+        )?
+        .into_array();
+        let result = list.apply(&list_length(root()))?;
+
+        let mut ctx = array_session().create_execution_ctx();
+        let result = result.execute::<PrimitiveArray>(&mut ctx)?;
+
+        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(0), None]);
+
+        assert_arrays_eq!(result, expected, &mut ctx);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_scalar_list_length() -> VortexResult<()> {
+        let null_scalar = Scalar::null(DType::List(
+            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
+            Nullability::Nullable,
+        ));
+        let array = ConstantArray::new(null_scalar, 2).into_array();
+        let result = array.apply(&list_length(root()))?;
+
+        let mut ctx = array_session().create_execution_ctx();
+        assert!(!result.is_valid(0, &mut ctx)?);
+        assert!(!result.is_valid(1, &mut ctx)?);
+        Ok(())
+    }
+
+    #[test]
+    fn test_listview_length() -> VortexResult<()> {
+        let elements = create_list_elements();
+        let lv = ListViewArray::new(
+            elements,
+            buffer![5u32, 0, 4, 1].into_array(),
+            buffer![2u32, 3, 0, 2].into_array(),
+            Validity::NonNullable,
+        )
+        .into_array();
+        let result = lv.apply(&list_length(root()))?;
+        let mut ctx = array_session().create_execution_ctx();
+        assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 3, 0, 2]), &mut ctx);
+        Ok(())
+    }
+
+    #[test]
+    fn test_listview_length_nullable() -> VortexResult<()> {
+        let elements = create_list_elements();
+        let lv = ListViewArray::new(
+            elements,
+            buffer![5u32, 0, 4, 1].into_array(),
+            buffer![2u32, 3, 0, 2].into_array(),
+            Validity::Array(BoolArray::from_iter([true, false, true, false]).into_array()),
+        )
+        .into_array();
+        let result = lv.apply(&list_length(root()))?;
+
+        let mut ctx = array_session().create_execution_ctx();
+        let result = result.execute::<PrimitiveArray>(&mut ctx)?;
+
+        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(0), None]);
+        assert_arrays_eq!(result, expected, &mut ctx);
+        Ok(())
+    }
+
+    #[test]
+    fn test_list_length_take() -> VortexResult<()> {
+        let elements = create_list_elements();
+        let list = ListArray::try_new(
+            elements,
+            buffer![0u32, 2, 5, 5, 7].into_array(),
+            Validity::NonNullable,
+        )?
+        .into_array();
+        let taken = list.take(buffer![3u64, 0, 2].into_array())?;
+
+        let result = taken.apply(&list_length(root()))?;
+        let mut ctx = array_session().create_execution_ctx();
+        assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 2, 0]), &mut ctx);
+        Ok(())
+    }
+
+    fn create_fixed_size_list(validity: Validity) -> ArrayRef {
+        // 4 lists of size 2 over 8 primitive elements.
+        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6, 7, 8]).into_array();
+        FixedSizeListArray::new(elements, 2, validity, 4).into_array()
+    }
+
+    #[test]
+    fn test_fixed_size_list_length() -> VortexResult<()> {
+        let fsl = create_fixed_size_list(Validity::NonNullable);
+        let result = fsl.apply(&list_length(root()))?;
+
+        let mut ctx = array_session().create_execution_ctx();
+        assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 2, 2, 2]), &mut ctx);
+        Ok(())
+    }
+
+    #[test]
+    fn test_fixed_size_list_length_nullable() -> VortexResult<()> {
+        let fsl = create_fixed_size_list(Validity::Array(
+            BoolArray::from_iter([true, false, true, false]).into_array(),
+        ));
+        let result = fsl.apply(&list_length(root()))?;
+
+        let mut ctx = array_session().create_execution_ctx();
+        let result = result.execute::<PrimitiveArray>(&mut ctx)?;
+
+        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(2), None]);
+        assert_arrays_eq!(result, expected, &mut ctx);
+        Ok(())
+    }
+
+    #[test]
+    fn test_fallible_child_expression_fails() -> VortexResult<()> {
+        let fsl = create_fixed_size_list(Validity::Array(
+            BoolArray::from_iter([true, false, true, false]).into_array(),
+        ));
+        let failing_cast_dtype = DType::FixedSizeList(
+            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
+            2,
+            Nullability::NonNullable,
+        );
+
+        let lengths = fsl.apply(&list_length(cast(root(), failing_cast_dtype)))?;
+
+        let mut ctx = array_session().create_execution_ctx();
+        let result = lengths.execute::<PrimitiveArray>(&mut ctx);
+
+        assert!(result.is_err());
+
+        let err_message = result.unwrap_err().to_string();
+
+        assert!(
+            err_message.contains("Cannot cast array with invalid values to non-nullable type.")
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_display() {
+        let expr = list_length(root());
+        assert_eq!(expr.to_string(), "vortex.list.length($)");
+    }
+}
diff --git a/vortex-array/src/scalar_fn/fns/mod.rs b/vortex-array/src/scalar_fn/fns/mod.rs
index 41843d0a18e..f7e98e676b0 100644
--- a/vortex-array/src/scalar_fn/fns/mod.rs
+++ b/vortex-array/src/scalar_fn/fns/mod.rs
@@ -14,6 +14,7 @@ pub mod is_not_null;
 pub mod is_null;
 pub mod like;
 pub mod list_contains;
+pub mod list_length;
 pub mod literal;
 pub mod mask;
 pub mod merge;
diff --git a/vortex-array/src/scalar_fn/session.rs b/vortex-array/src/scalar_fn/session.rs
index b85120e5345..2109b3cd3ce 100644
--- a/vortex-array/src/scalar_fn/session.rs
+++ b/vortex-array/src/scalar_fn/session.rs
@@ -20,6 +20,7 @@ use crate::scalar_fn::fns::is_not_null::IsNotNull;
 use crate::scalar_fn::fns::is_null::IsNull;
 use crate::scalar_fn::fns::like::Like;
 use crate::scalar_fn::fns::list_contains::ListContains;
+use crate::scalar_fn::fns::list_length::ListLength;
 use crate::scalar_fn::fns::literal::Literal;
 use crate::scalar_fn::fns::merge::Merge;
 use crate::scalar_fn::fns::not::Not;
@@ -67,6 +68,7 @@ impl Default for ScalarFnSession {
         this.register(IsNull);
         this.register(Like);
         this.register(ListContains);
+        this.register(ListLength);
         this.register(Literal);
         this.register(Merge);
         this.register(Not);