From ec42e758ce143df037ff9bcccc20c50eb8db098e Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 8 May 2026 23:16:02 -0400 Subject: [PATCH] Add stats expression Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 198 +++++++++++++++++++++++++ vortex-array/src/scalar_fn/fns/mod.rs | 1 + vortex-array/src/scalar_fn/fns/stat.rs | 141 ++++++++++++++++++ vortex-array/src/scalar_fn/session.rs | 2 + vortex-array/src/stats/expr.rs | 123 +++++++++++++++ vortex-array/src/stats/legacy.rs | 24 +++ vortex-array/src/stats/mod.rs | 3 + 7 files changed, 492 insertions(+) create mode 100644 vortex-array/src/scalar_fn/fns/stat.rs create mode 100644 vortex-array/src/stats/expr.rs create mode 100644 vortex-array/src/stats/legacy.rs diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 8d5cc75c895..989dbd0edc8 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -17800,6 +17800,84 @@ pub fn vortex_array::scalar_fn::fns::select::Select::stat_falsification(&self, & pub fn vortex_array::scalar_fn::fns::select::Select::validity(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> +pub mod vortex_array::scalar_fn::fns::stat + +pub struct vortex_array::scalar_fn::fns::stat::StatFn + +impl core::clone::Clone for vortex_array::scalar_fn::fns::stat::StatFn + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::clone(&self) -> vortex_array::scalar_fn::fns::stat::StatFn + +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::stat::StatFn + +pub type vortex_array::scalar_fn::fns::stat::StatFn::Options = vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::arity(&self, &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::coerce_args(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::ExecutionArgs, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::fmt_sql(&self, &Self::Options, &vortex_array::expr::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::reduce(&self, &Self::Options, &dyn vortex_array::scalar_fn::ReduceNode, &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::simplify(&self, &Self::Options, &vortex_array::expr::Expression, &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::simplify_untyped(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::stat_expression(&self, &Self::Options, &vortex_array::expr::Expression, vortex_array::expr::stats::Stat, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::stat_falsification(&self, &Self::Options, &vortex_array::expr::Expression, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::validity(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub struct vortex_array::scalar_fn::fns::stat::StatOptions + +impl vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::aggregate_fn(&self) -> &vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::new(vortex_array::aggregate_fn::AggregateFnRef) -> Self + +impl core::clone::Clone for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::clone(&self) -> vortex_array::scalar_fn::fns::stat::StatOptions + +impl core::cmp::Eq for vortex_array::scalar_fn::fns::stat::StatOptions + +impl core::cmp::PartialEq for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::eq(&self, &vortex_array::scalar_fn::fns::stat::StatOptions) -> bool + +impl core::fmt::Debug for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::StructuralPartialEq for vortex_array::scalar_fn::fns::stat::StatOptions + pub mod vortex_array::scalar_fn::fns::zip pub struct vortex_array::scalar_fn::fns::zip::Zip @@ -19066,6 +19144,44 @@ pub fn vortex_array::scalar_fn::fns::select::Select::stat_falsification(&self, & pub fn vortex_array::scalar_fn::fns::select::Select::validity(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::stat::StatFn + +pub type vortex_array::scalar_fn::fns::stat::StatFn::Options = vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::arity(&self, &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::coerce_args(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::ExecutionArgs, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::fmt_sql(&self, &Self::Options, &vortex_array::expr::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::reduce(&self, &Self::Options, &dyn vortex_array::scalar_fn::ReduceNode, &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::simplify(&self, &Self::Options, &vortex_array::expr::Expression, &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::simplify_untyped(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::stat_expression(&self, &Self::Options, &vortex_array::expr::Expression, vortex_array::expr::stats::Stat, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::stat_falsification(&self, &Self::Options, &vortex_array::expr::Expression, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::validity(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> + impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::zip::Zip pub type vortex_array::scalar_fn::fns::zip::Zip::Options = vortex_array::scalar_fn::EmptyOptions @@ -19450,6 +19566,86 @@ pub type vortex_array::session::ArrayRegistry = vortex_session::registry::Regist pub mod vortex_array::stats +pub mod vortex_array::stats::expr + +pub struct vortex_array::stats::expr::StatFn + +impl core::clone::Clone for vortex_array::scalar_fn::fns::stat::StatFn + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::clone(&self) -> vortex_array::scalar_fn::fns::stat::StatFn + +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::stat::StatFn + +pub type vortex_array::scalar_fn::fns::stat::StatFn::Options = vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::arity(&self, &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::coerce_args(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::ExecutionArgs, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::fmt_sql(&self, &Self::Options, &vortex_array::expr::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::reduce(&self, &Self::Options, &dyn vortex_array::scalar_fn::ReduceNode, &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::simplify(&self, &Self::Options, &vortex_array::expr::Expression, &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::simplify_untyped(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::stat_expression(&self, &Self::Options, &vortex_array::expr::Expression, vortex_array::expr::stats::Stat, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::stat_falsification(&self, &Self::Options, &vortex_array::expr::Expression, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stat::StatFn::validity(&self, &Self::Options, &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub struct vortex_array::stats::expr::StatOptions + +impl vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::aggregate_fn(&self) -> &vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::new(vortex_array::aggregate_fn::AggregateFnRef) -> Self + +impl core::clone::Clone for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::clone(&self) -> vortex_array::scalar_fn::fns::stat::StatOptions + +impl core::cmp::Eq for vortex_array::scalar_fn::fns::stat::StatOptions + +impl core::cmp::PartialEq for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::eq(&self, &vortex_array::scalar_fn::fns::stat::StatOptions) -> bool + +impl core::fmt::Debug for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::scalar_fn::fns::stat::StatOptions::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::StructuralPartialEq for vortex_array::scalar_fn::fns::stat::StatOptions + +pub fn vortex_array::stats::expr::stat(vortex_array::expr::Expression, vortex_array::aggregate_fn::AggregateFnRef) -> vortex_array::expr::Expression + pub mod vortex_array::stats::flatbuffers pub struct vortex_array::stats::ArrayStats @@ -19682,6 +19878,8 @@ pub const vortex_array::stats::PRUNING_STATS: &[vortex_array::expr::stats::Stat] pub fn vortex_array::stats::as_stat_bitset_bytes(&[vortex_array::expr::stats::Stat]) -> alloc::vec::Vec +pub fn vortex_array::stats::stat(vortex_array::expr::Expression, vortex_array::aggregate_fn::AggregateFnRef) -> vortex_array::expr::Expression + pub fn vortex_array::stats::stats_from_bitset_bytes(&[u8]) -> alloc::vec::Vec pub type vortex_array::stats::StatsArray = [(vortex_array::expr::stats::Stat, vortex_array::expr::stats::Precision); 4] diff --git a/vortex-array/src/scalar_fn/fns/mod.rs b/vortex-array/src/scalar_fn/fns/mod.rs index 8fa1b66532d..5ed76395394 100644 --- a/vortex-array/src/scalar_fn/fns/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mod.rs @@ -20,4 +20,5 @@ pub mod operators; pub mod pack; pub mod root; pub mod select; +pub mod stat; pub mod zip; diff --git a/vortex-array/src/scalar_fn/fns/stat.rs b/vortex-array/src/scalar_fn/fns/stat.rs new file mode 100644 index 00000000000..f7557b89a37 --- /dev/null +++ b/vortex-array/src/scalar_fn/fns/stat.rs @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Scalar function implementation for aggregate-backed stat expressions. + +use std::fmt::Display; +use std::fmt::Formatter; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::AggregateFnRef; +use crate::arrays::Chunked; +use crate::arrays::ChunkedArray; +use crate::arrays::ConstantArray; +use crate::arrays::chunked::ChunkedArrayExt; +use crate::dtype::DType; +use crate::expr::Expression; +use crate::expr::stats::StatsProvider; +use crate::scalar::Scalar; +use crate::scalar_fn::Arity; +use crate::scalar_fn::ChildName; +use crate::scalar_fn::ExecutionArgs; +use crate::scalar_fn::ScalarFnId; +use crate::scalar_fn::ScalarFnVTable; +use crate::stats::legacy::legacy_stat_for_aggregate; + +/// Options for the `stat` scalar function. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct StatOptions { + aggregate_fn: AggregateFnRef, +} + +impl StatOptions { + /// Creates options for the provided aggregate statistic. + pub fn new(aggregate_fn: AggregateFnRef) -> Self { + Self { aggregate_fn } + } + + /// Returns the aggregate function backing this statistic lookup. + pub fn aggregate_fn(&self) -> &AggregateFnRef { + &self.aggregate_fn + } +} + +impl Display for StatOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.aggregate_fn, f) + } +} + +/// Scalar function that evaluates stored aggregate statistics. +#[derive(Clone)] +pub struct StatFn; + +impl ScalarFnVTable for StatFn { + type Options = StatOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new("vortex.stat") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + _ => unreachable!("Invalid child index {} for Stat expression", child_idx), + } + } + + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "stat(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", {})", options.aggregate_fn()) + } + + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + stat_dtype(options.aggregate_fn(), &arg_dtypes[0]) + } + + fn execute( + &self, + options: &Self::Options, + args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + let input = args.get(0)?; + let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?; + + if let Some(chunked) = input.as_opt::() { + let chunks = chunked + .iter_chunks() + .map(|chunk| stat_array(chunk, options.aggregate_fn(), dtype.clone(), chunk.len())) + .collect::>>()?; + return Ok(ChunkedArray::try_new(chunks, dtype)?.into_array()); + } + + stat_array(&input, options.aggregate_fn(), dtype, args.row_count()) + } +} + +fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult { + let Some(dtype) = aggregate_fn.return_dtype(input_dtype) else { + vortex_bail!( + "Aggregate function {} does not support input dtype {}", + aggregate_fn, + input_dtype + ); + }; + Ok(dtype.as_nullable()) +} + +fn stat_array( + array: &ArrayRef, + aggregate_fn: &AggregateFnRef, + dtype: DType, + len: usize, +) -> VortexResult { + let value = legacy_stat_for_aggregate(aggregate_fn) + .and_then(|stat| { + array + .statistics() + .with_typed_stats_set(|stats| stats.get(stat)) + }) + .and_then(|stat| stat.as_exact()) + .and_then(Scalar::into_value); + + let scalar = Scalar::try_new(dtype, value)?; + Ok(ConstantArray::new(scalar, len).into_array()) +} diff --git a/vortex-array/src/scalar_fn/session.rs b/vortex-array/src/scalar_fn/session.rs index b3b4065f5fa..d541c6cbe4c 100644 --- a/vortex-array/src/scalar_fn/session.rs +++ b/vortex-array/src/scalar_fn/session.rs @@ -26,6 +26,7 @@ use crate::scalar_fn::fns::not::Not; use crate::scalar_fn::fns::pack::Pack; use crate::scalar_fn::fns::root::Root; use crate::scalar_fn::fns::select::Select; +use crate::scalar_fn::fns::stat::StatFn; /// Registry of scalar function vtables. pub type ScalarFnRegistry = Registry; @@ -70,6 +71,7 @@ impl Default for ScalarFnSession { this.register(Pack); this.register(Root); this.register(Select); + this.register(StatFn); this } diff --git a/vortex-array/src/stats/expr.rs b/vortex-array/src/stats/expr.rs new file mode 100644 index 00000000000..30d207fcd0b --- /dev/null +++ b/vortex-array/src/stats/expr.rs @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Expression constructors for statistics backed by aggregate functions. + +use crate::aggregate_fn::AggregateFnRef; +use crate::expr::Expression; +use crate::scalar_fn::ScalarFnVTableExt; +pub use crate::scalar_fn::fns::stat::StatFn; +pub use crate::scalar_fn::fns::stat::StatOptions; + +/// Creates an expression that reads a stored aggregate statistic for `expr`. +/// +/// If the statistic is not available in the current stats scope, evaluating the expression returns +/// a nullable all-null array with the aggregate return type. +pub fn stat(expr: Expression, aggregate_fn: AggregateFnRef) -> Expression { + StatFn.new_expr(StatOptions::new(aggregate_fn), [expr]) +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexExpect; + use vortex_error::VortexResult; + + use super::stat; + use crate::Canonical; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::AggregateFn; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::sum::Sum; + use crate::arrays::Chunked; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::chunked::ChunkedArrayExt; + use crate::assert_arrays_eq; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::expr::root; + use crate::expr::stats::Precision; + use crate::expr::stats::Stat; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn stat_expr_reads_cached_sum() -> VortexResult<()> { + let array = buffer![1i32, 2, 3].into_array(); + let sum_scalar = Scalar::primitive(6i64, Nullability::Nullable); + array.statistics().set( + Stat::Sum, + Precision::exact(sum_scalar.into_value().vortex_expect("non-null sum")), + ); + + let result = array + .apply(&stat(root(), AggregateFn::new(Sum, EmptyOptions).erased()))? + .execute::(&mut LEGACY_SESSION.create_execution_ctx())? + .into_array(); + + let expected = + ConstantArray::new(Scalar::primitive(6i64, Nullability::Nullable), 3).into_array(); + assert_arrays_eq!(result, expected); + + Ok(()) + } + + #[test] + fn stat_expr_returns_null_when_sum_is_missing() -> VortexResult<()> { + let array = buffer![1i32, 2, 3].into_array(); + + let result = array + .apply(&stat(root(), AggregateFn::new(Sum, EmptyOptions).erased()))? + .execute::(&mut LEGACY_SESSION.create_execution_ctx())? + .into_array(); + + let expected = ConstantArray::new( + Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable)), + 3, + ) + .into_array(); + assert_arrays_eq!(result, expected); + + Ok(()) + } + + #[test] + fn stat_expr_reads_cached_sum_per_chunk() -> VortexResult<()> { + let chunk0 = buffer![1i32, 2].into_array(); + let sum_scalar = Scalar::primitive(3i64, Nullability::Nullable); + chunk0.statistics().set( + Stat::Sum, + Precision::exact(sum_scalar.into_value().vortex_expect("non-null sum")), + ); + let chunk1 = buffer![4i32, 5, 6].into_array(); + let chunked = ChunkedArray::try_new( + vec![chunk0, chunk1], + DType::Primitive(PType::I32, Nullability::NonNullable), + )? + .into_array(); + + let result = chunked.apply(&stat(root(), AggregateFn::new(Sum, EmptyOptions).erased()))?; + + let chunked_result = result + .as_opt::() + .vortex_expect("stat expression should preserve chunked alignment"); + assert_eq!(chunked_result.nchunks(), 2); + + let result = result + .execute::(&mut LEGACY_SESSION.create_execution_ctx())? + .into_array(); + let expected = PrimitiveArray::new( + buffer![3i64, 3, 0, 0, 0], + Validity::from_iter([true, true, false, false, false]), + ) + .into_array(); + assert_arrays_eq!(result, expected); + + Ok(()) + } +} diff --git a/vortex-array/src/stats/legacy.rs b/vortex-array/src/stats/legacy.rs new file mode 100644 index 00000000000..fc5faf2712f --- /dev/null +++ b/vortex-array/src/stats/legacy.rs @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Compatibility helpers for stats still stored under the legacy [`Stat`] enum. + +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::aggregate_fn::fns::sum::Sum; +use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; +use crate::expr::stats::Stat; + +pub(crate) fn legacy_stat_for_aggregate(aggregate_fn: &AggregateFnRef) -> Option { + if aggregate_fn.is::() { + return Some(Stat::Sum); + } + if aggregate_fn.is::() { + return Some(Stat::NaNCount); + } + if aggregate_fn.is::() { + return Some(Stat::UncompressedSizeInBytes); + } + + None +} diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index ea8ae6b58a6..1928c78b7c4 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -7,10 +7,13 @@ use arrow_buffer::BooleanBufferBuilder; use arrow_buffer::MutableBuffer; use arrow_buffer::bit_iterator::BitIterator; use enum_iterator::last; +pub use expr::stat; pub use stats_set::*; mod array; +pub mod expr; pub mod flatbuffers; +pub(crate) mod legacy; mod stats_set; pub use array::*;