diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 2b72aab9d90..70ff659eda1 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -54,7 +54,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::fmt(&self, f: &mut core::fmt:: impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum -pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions +pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::EmptyOptions pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial @@ -64,7 +64,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult @@ -82,33 +82,19 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> -pub struct vortex_array::aggregate_fn::fns::sum::SumOptions - -impl core::clone::Clone for vortex_array::aggregate_fn::fns::sum::SumOptions - -pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::clone(&self) -> vortex_array::aggregate_fn::fns::sum::SumOptions - -impl core::cmp::Eq for vortex_array::aggregate_fn::fns::sum::SumOptions - -impl core::cmp::PartialEq for vortex_array::aggregate_fn::fns::sum::SumOptions - -pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::eq(&self, other: &vortex_array::aggregate_fn::fns::sum::SumOptions) -> bool - -impl core::fmt::Debug for vortex_array::aggregate_fn::fns::sum::SumOptions +pub struct vortex_array::aggregate_fn::fns::sum::SumPartial -pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub mod vortex_array::aggregate_fn::kernels -impl core::fmt::Display for vortex_array::aggregate_fn::fns::sum::SumOptions +pub trait vortex_array::aggregate_fn::kernels::DynAggregateKernel: 'static + core::marker::Send + core::marker::Sync + core::fmt::Debug -pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_array::aggregate_fn::kernels::DynAggregateKernel::aggregate(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl core::hash::Hash for vortex_array::aggregate_fn::fns::sum::SumOptions +pub trait vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel: 'static + core::marker::Send + core::marker::Sync + core::fmt::Debug -pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) +pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_aggregate(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, groups: &vortex_array::arrays::ListViewArray) -> vortex_error::VortexResult> -impl core::marker::StructuralPartialEq for vortex_array::aggregate_fn::fns::sum::SumOptions - -pub struct vortex_array::aggregate_fn::fns::sum::SumPartial +pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_aggregate_fixed_size(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, groups: &vortex_array::arrays::FixedSizeListArray) -> vortex_error::VortexResult> pub mod vortex_array::aggregate_fn::session @@ -118,11 +104,13 @@ impl vortex_array::aggregate_fn::session::AggregateFnSession pub fn vortex_array::aggregate_fn::session::AggregateFnSession::register(&self, vtable: V) +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::register_aggregate_kernel(&self, array_id: vortex_array::vtable::ArrayId, agg_fn_id: core::option::Option, kernel: &'static dyn vortex_array::aggregate_fn::kernels::DynAggregateKernel) + pub fn vortex_array::aggregate_fn::session::AggregateFnSession::registry(&self) -> &vortex_array::aggregate_fn::session::AggregateFnRegistry impl core::default::Default for vortex_array::aggregate_fn::session::AggregateFnSession -pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> vortex_array::aggregate_fn::session::AggregateFnSession +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> Self impl core::fmt::Debug for vortex_array::aggregate_fn::session::AggregateFnSession @@ -324,7 +312,7 @@ pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum -pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions +pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::EmptyOptions pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial @@ -334,7 +322,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 7ed2df38edc..d3fab024739 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -105,9 +105,17 @@ impl DynAccumulator for Accumulator { break; } - let kernel_key = (self.vtable.id(), batch.encoding_id()); - if let Some(kernel) = kernels.read().get(&kernel_key) - && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch)? + let kernels_r = kernels.read(); + let batch_id = batch.encoding_id(); + if let Some(result) = kernels_r + .get(&(batch_id.clone(), Some(self.aggregate_fn.id()))) + .or_else(|| kernels_r.get(&(batch_id, None))) + .and_then(|kernel| { + kernel + .aggregate(&self.aggregate_fn, &batch, &mut ctx) + .transpose() + }) + .transpose()? { vortex_ensure!( result.dtype() == &self.partial_dtype, diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 8a782cc57bd..ff85da637b9 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -164,21 +164,27 @@ impl GroupedAccumulator { break; } - let kernel_key = (self.vtable.id(), elements.encoding_id()); - if let Some(kernel) = kernels.read().get(&kernel_key) { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - ListViewArray::new_unchecked( - elements.clone(), - groups.offsets().clone(), - groups.sizes().clone(), - groups.validity().clone(), - ) - }; - - if let Some(result) = kernel.grouped_aggregate(&self.aggregate_fn, &groups)? { - return self.push_result(result); - } + let kernels_r = kernels.read(); + if let Some(result) = kernels_r + .get(&(elements.encoding_id(), Some(self.aggregate_fn.id()))) + .or_else(|| kernels_r.get(&(elements.encoding_id(), None))) + .and_then(|kernel| { + // SAFETY: we assume that elements execution is safe + let groups = unsafe { + ListViewArray::new_unchecked( + elements.clone(), + groups.offsets().clone(), + groups.sizes().clone(), + groups.validity().clone(), + ) + }; + kernel + .grouped_aggregate(&self.aggregate_fn, &groups) + .transpose() + }) + .transpose()? + { + return self.push_result(result); } // Execute one step and try again @@ -244,23 +250,28 @@ impl GroupedAccumulator { break; } - let kernel_key = (self.vtable.id(), elements.encoding_id()); - if let Some(kernel) = kernels.read().get(&kernel_key) { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - FixedSizeListArray::new_unchecked( - elements.clone(), - groups.list_size(), - groups.validity().clone(), - groups.len(), - ) - }; - - if let Some(result) = - kernel.grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)? - { - return self.push_result(result); - } + let kernels_r = kernels.read(); + if let Some(result) = kernels_r + .get(&(elements.encoding_id(), Some(self.aggregate_fn.id()))) + .or_else(|| kernels_r.get(&(elements.encoding_id(), None))) + .and_then(|kernel| { + // SAFETY: we assume that elements execution is safe + let groups = unsafe { + FixedSizeListArray::new_unchecked( + elements.clone(), + groups.list_size(), + groups.validity().clone(), + groups.len(), + ) + }; + + kernel + .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups) + .transpose() + }) + .transpose()? + { + return self.push_result(result); } // Execute one step and try again diff --git a/vortex-array/src/aggregate_fn/fns/sum.rs b/vortex-array/src/aggregate_fn/fns/sum.rs index 1a4c73f8084..5b16b585e37 100644 --- a/vortex-array/src/aggregate_fn/fns/sum.rs +++ b/vortex-array/src/aggregate_fn/fns/sum.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::fmt::Display; -use std::fmt::Formatter; use std::ops::BitAnd; use itertools::Itertools; @@ -19,6 +17,7 @@ use crate::Canonical; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::EmptyOptions; use crate::arrays::BoolArray; use crate::arrays::DecimalArray; use crate::arrays::PrimitiveArray; @@ -34,23 +33,8 @@ use crate::scalar::Scalar; #[derive(Clone, Debug)] pub struct Sum; -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct SumOptions { - checked: bool, -} - -impl Display for SumOptions { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - if self.checked { - write!(f, "checked") - } else { - write!(f, "unchecked") - } - } -} - impl AggregateFnVTable for Sum { - type Options = SumOptions; + type Options = EmptyOptions; type Partial = SumPartial; fn id(&self) -> AggregateFnId { @@ -69,25 +53,16 @@ impl AggregateFnVTable for Sum { fn empty_partial( &self, - options: &Self::Options, + _options: &Self::Options, input_dtype: &DType, ) -> VortexResult { let return_dtype = Stat::Sum .dtype(input_dtype) .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?; - let initial = match &return_dtype { - DType::Primitive(ptype, _) => match ptype { - PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0), - PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0), - PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0), - }, - DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)), - _ => vortex_panic!("Unsupported sum type"), - }; + let initial = make_zero_state(&return_dtype); Ok(SumPartial { - checked: options.checked, return_dtype, current: Some(initial), }) @@ -95,9 +70,10 @@ impl AggregateFnVTable for Sum { fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { if other.is_null() { + // A null partial means the sub-accumulator saturated (overflow). + partial.current = None; return Ok(()); } - let checked = partial.checked; let Some(ref mut inner) = partial.current else { return Ok(()); }; @@ -106,21 +82,21 @@ impl AggregateFnVTable for Sum { let val = other .as_primitive() .typed_value::() - .ok_or_else(|| vortex_err!("Expected u64 scalar for unsigned sum merge"))?; - add_u64(acc, val, checked) + .vortex_expect("checked non-null"); + checked_add_u64(acc, val) } SumState::Signed(acc) => { let val = other .as_primitive() .typed_value::() - .ok_or_else(|| vortex_err!("Expected i64 scalar for signed sum merge"))?; - add_i64(acc, val, checked) + .vortex_expect("checked non-null"); + checked_add_i64(acc, val) } SumState::Float(acc) => { let val = other .as_primitive() .typed_value::() - .ok_or_else(|| vortex_err!("Expected f64 scalar for float sum merge"))?; + .vortex_expect("checked non-null"); *acc += val; false } @@ -128,7 +104,7 @@ impl AggregateFnVTable for Sum { let val = other .as_decimal() .decimal_value() - .ok_or_else(|| vortex_err!("Expected decimal scalar for decimal sum merge"))?; + .vortex_expect("checked non-null"); match acc.checked_add(&val) { Some(r) => { *acc = r; @@ -176,15 +152,14 @@ impl AggregateFnVTable for Sum { batch: &Canonical, _ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let checked = partial.checked; let mut inner = match partial.current.take() { Some(inner) => inner, None => return Ok(()), }; let result = match batch { - Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, checked), - Canonical::Bool(b) => accumulate_bool(&mut inner, b, checked), + Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), + Canonical::Bool(b) => accumulate_bool(&mut inner, b), Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), }; @@ -212,7 +187,6 @@ impl AggregateFnVTable for Sum { /// The group state for a sum aggregate, containing the accumulated value and configuration /// needed for reset/result without external context. pub struct SumPartial { - checked: bool, return_dtype: DType, /// The current accumulated state, or `None` if saturated (checked overflow). current: Option, @@ -241,62 +215,45 @@ fn make_zero_state(return_dtype: &DType) -> SumState { } } -/// Add `val` to `acc`, returning true if overflow occurred (checked mode) or wrapping (unchecked). -fn add_u64(acc: &mut u64, val: u64, checked: bool) -> bool { - if checked { - match acc.checked_add(val) { - Some(r) => { - *acc = r; - false - } - None => true, +/// Checked add for u64, returning true if overflow occurred. +#[inline(always)] +fn checked_add_u64(acc: &mut u64, val: u64) -> bool { + match acc.checked_add(val) { + Some(r) => { + *acc = r; + false } - } else { - *acc = acc.wrapping_add(val); - false + None => true, } } -fn add_i64(acc: &mut i64, val: i64, checked: bool) -> bool { - if checked { - match acc.checked_add(val) { - Some(r) => { - *acc = r; - false - } - None => true, +/// Checked add for i64, returning true if overflow occurred. +#[inline(always)] +fn checked_add_i64(acc: &mut i64, val: i64) -> bool { + match acc.checked_add(val) { + Some(r) => { + *acc = r; + false } - } else { - *acc = acc.wrapping_add(val); - false + None => true, } } -/// Accumulate a primitive array into the sum state. -/// Returns Ok(true) if saturated (overflow), Ok(false) if not. -fn accumulate_primitive( - inner: &mut SumState, - p: &PrimitiveArray, - checked: bool, -) -> VortexResult { +fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { let mask = p.validity_mask()?; match mask.bit_buffer() { AllOr::None => Ok(false), - AllOr::All => accumulate_primitive_all(inner, p, checked), - AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity, checked), + AllOr::All => accumulate_primitive_all(inner, p), + AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity), } } -fn accumulate_primitive_all( - inner: &mut SumState, - p: &PrimitiveArray, - checked: bool, -) -> VortexResult { +fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { match inner { SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), unsigned: |T| { for &v in p.as_slice::() { - if add_u64(acc, v.to_u64().vortex_expect("unsigned to u64"), checked) { + if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) { return Ok(true); } } @@ -309,7 +266,7 @@ fn accumulate_primitive_all( unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") }, signed: |T| { for &v in p.as_slice::() { - if add_i64(acc, v.to_i64().vortex_expect("signed to i64"), checked) { + if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) { return Ok(true); } } @@ -335,13 +292,12 @@ fn accumulate_primitive_valid( inner: &mut SumState, p: &PrimitiveArray, validity: &vortex_buffer::BitBuffer, - checked: bool, ) -> VortexResult { match inner { SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), unsigned: |T| { for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { - if valid && add_u64(acc, v.to_u64().vortex_expect("unsigned to u64"), checked) { + if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) { return Ok(true); } } @@ -354,7 +310,7 @@ fn accumulate_primitive_valid( unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") }, signed: |T| { for (&v, valid) in p.as_slice::().iter().zip_eq(validity.iter()) { - if valid && add_i64(acc, v.to_i64().vortex_expect("signed to i64"), checked) { + if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) { return Ok(true); } } @@ -378,9 +334,7 @@ fn accumulate_primitive_valid( } } -/// Accumulate a boolean array into the sum state (counts true values as u64). -/// Returns Ok(true) if saturated (overflow), Ok(false) if not. -fn accumulate_bool(inner: &mut SumState, b: &BoolArray, checked: bool) -> VortexResult { +fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult { let SumState::Unsigned(acc) = inner else { vortex_panic!("expected unsigned sum state for bool input"); }; @@ -392,7 +346,7 @@ fn accumulate_bool(inner: &mut SumState, b: &BoolArray, checked: bool) -> Vortex AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64, }; - Ok(add_u64(acc, true_count, checked)) + Ok(checked_add_u64(acc, true_count)) } /// Accumulate a decimal array into the sum state. @@ -440,9 +394,9 @@ mod tests { use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::EmptyOptions; use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::sum::Sum; - use crate::aggregate_fn::fns::sum::SumOptions; use crate::arrays::BoolArray; use crate::arrays::FixedSizeListArray; use crate::arrays::PrimitiveArray; @@ -457,16 +411,8 @@ mod tests { VortexSession::empty() } - fn checked_opts() -> SumOptions { - SumOptions { checked: true } - } - - fn unchecked_opts() -> SumOptions { - SumOptions { checked: false } - } - - fn run_sum(batch: &ArrayRef, options: &SumOptions) -> VortexResult { - let mut acc = Accumulator::try_new(Sum, options.clone(), batch.dtype().clone(), session())?; + fn run_sum(batch: &ArrayRef) -> VortexResult { + let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?; acc.accumulate(batch)?; acc.finish() } @@ -476,7 +422,7 @@ mod tests { #[test] fn sum_i32() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let result = run_sum(&arr, &checked_opts())?; + let result = run_sum(&arr)?; assert_eq!(result.as_primitive().typed_value::(), Some(10)); Ok(()) } @@ -484,7 +430,7 @@ mod tests { #[test] fn sum_u8() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array(); - let result = run_sum(&arr, &checked_opts())?; + let result = run_sum(&arr)?; assert_eq!(result.as_primitive().typed_value::(), Some(60)); Ok(()) } @@ -493,7 +439,7 @@ mod tests { fn sum_f64() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array(); - let result = run_sum(&arr, &checked_opts())?; + let result = run_sum(&arr)?; assert_eq!(result.as_primitive().typed_value::(), Some(7.0)); Ok(()) } @@ -501,7 +447,7 @@ mod tests { #[test] fn sum_with_nulls() -> VortexResult<()> { let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array(); - let result = run_sum(&arr, &checked_opts())?; + let result = run_sum(&arr)?; assert_eq!(result.as_primitive().typed_value::(), Some(6)); Ok(()) } @@ -510,7 +456,7 @@ mod tests { fn sum_all_null() -> VortexResult<()> { // Arrow semantics: sum of all nulls is zero (identity element) let arr = PrimitiveArray::from_option_iter([None::, None, None]).into_array(); - let result = run_sum(&arr, &checked_opts())?; + let result = run_sum(&arr)?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) } @@ -520,7 +466,7 @@ mod tests { #[test] fn sum_empty_produces_zero() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -529,7 +475,7 @@ mod tests { #[test] fn sum_empty_f64_produces_zero() -> VortexResult<()> { let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); Ok(()) @@ -540,7 +486,7 @@ mod tests { #[test] fn sum_multi_batch() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); acc.accumulate(&batch1)?; @@ -556,7 +502,7 @@ mod tests { #[test] fn sum_finish_resets_state() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); acc.accumulate(&batch1)?; @@ -575,7 +521,7 @@ mod tests { #[test] fn sum_state_merge() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut state = Sum.empty_partial(&checked_opts(), &dtype)?; + let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?; let scalar1 = Scalar::primitive(100i64, Nullability::Nullable); Sum.combine_partials(&mut state, scalar1)?; @@ -593,7 +539,7 @@ mod tests { #[test] fn sum_checked_overflow() -> VortexResult<()> { let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); - let result = run_sum(&arr, &checked_opts())?; + let result = run_sum(&arr)?; assert!(result.is_null()); Ok(()) } @@ -601,7 +547,7 @@ mod tests { #[test] fn sum_checked_overflow_is_saturated() -> VortexResult<()> { let dtype = DType::Primitive(PType::I64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; assert!(!acc.is_saturated()); let batch = @@ -615,23 +561,12 @@ mod tests { Ok(()) } - #[test] - fn sum_unchecked_wrapping() -> VortexResult<()> { - let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); - let result = run_sum(&arr, &unchecked_opts())?; - assert_eq!( - result.as_primitive().typed_value::(), - Some(i64::MAX.wrapping_add(1)) - ); - Ok(()) - } - // Boolean sum tests #[test] fn sum_bool_all_true() -> VortexResult<()> { let arr: BoolArray = [true, true, true].into_iter().collect(); - let result = run_sum(&arr.into_array(), &checked_opts())?; + let result = run_sum(&arr.into_array())?; assert_eq!(result.as_primitive().typed_value::(), Some(3)); Ok(()) } @@ -639,7 +574,7 @@ mod tests { #[test] fn sum_bool_mixed() -> VortexResult<()> { let arr: BoolArray = [true, false, true, false, true].into_iter().collect(); - let result = run_sum(&arr.into_array(), &checked_opts())?; + let result = run_sum(&arr.into_array())?; assert_eq!(result.as_primitive().typed_value::(), Some(3)); Ok(()) } @@ -647,7 +582,7 @@ mod tests { #[test] fn sum_bool_all_false() -> VortexResult<()> { let arr: BoolArray = [false, false, false].into_iter().collect(); - let result = run_sum(&arr.into_array(), &checked_opts())?; + let result = run_sum(&arr.into_array())?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) } @@ -655,7 +590,7 @@ mod tests { #[test] fn sum_bool_with_nulls() -> VortexResult<()> { let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]); - let result = run_sum(&arr.into_array(), &checked_opts())?; + let result = run_sum(&arr.into_array())?; assert_eq!(result.as_primitive().typed_value::(), Some(2)); Ok(()) } @@ -664,7 +599,7 @@ mod tests { fn sum_bool_all_null() -> VortexResult<()> { // Arrow semantics: sum of all nulls is zero (identity element) let arr = BoolArray::from_iter([None::, None, None]); - let result = run_sum(&arr.into_array(), &checked_opts())?; + let result = run_sum(&arr.into_array())?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) } @@ -672,7 +607,7 @@ mod tests { #[test] fn sum_bool_empty_produces_zero() -> VortexResult<()> { let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -681,7 +616,7 @@ mod tests { #[test] fn sum_bool_finish_resets_state() -> VortexResult<()> { let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, checked_opts(), dtype, session())?; + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?; let batch1: BoolArray = [true, true, false].into_iter().collect(); acc.accumulate(&batch1.into_array())?; @@ -697,7 +632,7 @@ mod tests { #[test] fn sum_bool_return_dtype() -> VortexResult<()> { - let dtype = Sum.return_dtype(&checked_opts(), &DType::Bool(Nullability::NonNullable))?; + let dtype = Sum.return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable))?; assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable)); Ok(()) } @@ -706,7 +641,7 @@ mod tests { fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { let mut acc = - GroupedAccumulator::try_new(Sum, checked_opts(), elem_dtype.clone(), session())?; + GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?; acc.accumulate_list(groups)?; acc.finish() } @@ -793,7 +728,7 @@ mod tests { #[test] fn grouped_sum_finish_resets() -> VortexResult<()> { let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = GroupedAccumulator::try_new(Sum, checked_opts(), elem_dtype, session())?; + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?; // First batch: [[1, 2], [3, 4]] let elements1 = diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index 02d481ec180..ec659203c7a 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -9,6 +9,7 @@ use std::fmt::Debug; use vortex_error::VortexResult; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; use crate::arrays::FixedSizeListArray; use crate::arrays::ListViewArray; @@ -23,6 +24,7 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { &self, aggregate_fn: &AggregateFnRef, batch: &ArrayRef, + ctx: &mut ExecutionCtx, ) -> VortexResult>; } diff --git a/vortex-array/src/aggregate_fn/mod.rs b/vortex-array/src/aggregate_fn/mod.rs index b3b922592ec..4f89d5a4178 100644 --- a/vortex-array/src/aggregate_fn/mod.rs +++ b/vortex-array/src/aggregate_fn/mod.rs @@ -30,7 +30,7 @@ mod options; pub use options::*; pub mod fns; -mod kernels; +pub mod kernels; pub mod session; /// A unique identifier for an aggregate function. diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 8813544207c..2c018dbefd6 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -14,19 +14,37 @@ use crate::aggregate_fn::AggregateFnPluginRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::kernels::DynAggregateKernel; use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; +use crate::arrays::ChunkedVTable; +use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate; use crate::vtable::ArrayId; /// Registry of aggregate function vtables. pub type AggregateFnRegistry = Registry; /// Session state for aggregate function vtables. -#[derive(Debug, Default)] +#[derive(Debug)] pub struct AggregateFnSession { registry: AggregateFnRegistry, - pub(super) kernels: RwLock>, - pub(super) grouped_kernels: - RwLock>, + pub(super) kernels: RwLock>, + pub(super) grouped_kernels: RwLock>, +} + +type KernelKey = (ArrayId, Option); + +impl Default for AggregateFnSession { + fn default() -> Self { + let this = Self { + registry: AggregateFnRegistry::default(), + kernels: RwLock::new(HashMap::default()), + grouped_kernels: RwLock::new(HashMap::default()), + }; + + // Register the built-in aggregate kernels. + this.register_aggregate_kernel(ChunkedVTable::ID, None, &ChunkedArrayAggregate); + + this + } } impl AggregateFnSession { @@ -41,6 +59,16 @@ impl AggregateFnSession { self.registry .register(vtable.id(), Arc::new(vtable) as AggregateFnPluginRef); } + + /// Register an aggregate function kernel for a specific aggregate function and array type. + pub fn register_aggregate_kernel( + &self, + array_id: ArrayId, + agg_fn_id: Option, + kernel: &'static dyn DynAggregateKernel, + ) { + self.kernels.write().insert((array_id, agg_fn_id), kernel); + } } /// Extension trait for accessing aggregate function session data. diff --git a/vortex-array/src/arrays/chunked/compute/aggregate.rs b/vortex-array/src/arrays/chunked/compute/aggregate.rs new file mode 100644 index 00000000000..08f600dd00d --- /dev/null +++ b/vortex-array/src/arrays/chunked/compute/aggregate.rs @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::kernels::DynAggregateKernel; +use crate::arrays::ChunkedVTable; +use crate::scalar::Scalar; + +#[derive(Debug)] +pub struct ChunkedArrayAggregate; + +impl DynAggregateKernel for ChunkedArrayAggregate { + fn aggregate( + &self, + aggregate_fn: &AggregateFnRef, + batch: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(chunked) = batch.as_opt::() else { + return Ok(None); + }; + + let mut acc = aggregate_fn.accumulator(chunked.dtype(), ctx.session())?; + for chunk in chunked.chunks() { + acc.accumulate(chunk)?; + } + Ok(Some(acc.finish()?)) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::Buffer; + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::IntoArray; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::sum::Sum; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::scalar::Scalar; + + fn session() -> VortexSession { + VortexSession::empty() + } + + fn run_sum(batch: &crate::ArrayRef) -> VortexResult { + let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?; + acc.accumulate(batch)?; + acc.finish() + } + + #[test] + fn sum_chunked_i32() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![ + buffer![1i32, 2, 3].into_array(), + buffer![4i32, 5, 6].into_array(), + ], + DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(21)); + Ok(()) + } + + #[test] + fn sum_chunked_f64() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![ + buffer![1.5f64, 2.5].into_array(), + buffer![3.0f64].into_array(), + ], + DType::Primitive(PType::F64, Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(7.0)); + Ok(()) + } + + #[test] + fn sum_chunked_with_nulls() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![ + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(), + PrimitiveArray::from_option_iter([None, Some(5)]).into_array(), + ], + DType::Primitive(PType::I32, Nullability::Nullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(9)); + Ok(()) + } + + #[test] + fn sum_chunked_all_null() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![ + PrimitiveArray::from_option_iter([None::, None]).into_array(), + PrimitiveArray::from_option_iter([None::]).into_array(), + ], + DType::Primitive(PType::I32, Nullability::Nullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_chunked_single_chunk() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![buffer![10i32, 20, 30].into_array()], + DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(60)); + Ok(()) + } + + #[test] + fn sum_chunked_empty_chunks() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let chunked = ChunkedArray::try_new( + vec![ + Buffer::::empty().into_array(), + buffer![1i32, 2, 3].into_array(), + Buffer::::empty().into_array(), + buffer![4i32, 5].into_array(), + Buffer::::empty().into_array(), + ], + dtype, + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(15)); + Ok(()) + } + + #[test] + fn sum_chunked_all_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let chunked = ChunkedArray::try_new(vec![], dtype)?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn sum_chunked_many_small_chunks() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![ + buffer![1i32].into_array(), + buffer![2i32].into_array(), + buffer![3i32].into_array(), + buffer![4i32].into_array(), + buffer![5i32].into_array(), + ], + DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(15)); + Ok(()) + } + + #[test] + fn sum_chunked_u64() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![ + buffer![100u64, 200].into_array(), + buffer![300u64].into_array(), + ], + DType::Primitive(PType::U64, Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(600)); + Ok(()) + } + + #[test] + fn sum_chunked_bool() -> VortexResult<()> { + let b1: BoolArray = [true, false, true].into_iter().collect(); + let b2: BoolArray = [true, true].into_iter().collect(); + let chunked = ChunkedArray::try_new( + vec![b1.into_array(), b2.into_array()], + DType::Bool(Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(4)); + Ok(()) + } + + #[test] + fn sum_chunked_bool_with_nulls() -> VortexResult<()> { + let b1 = BoolArray::from_iter([Some(true), None, Some(true)]); + let b2 = BoolArray::from_iter([Some(false), None]); + let chunked = ChunkedArray::try_new( + vec![b1.into_array(), b2.into_array()], + DType::Bool(Nullability::Nullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(2)); + Ok(()) + } + + #[test] + fn sum_chunked_checked_overflow() -> VortexResult<()> { + let chunked = ChunkedArray::try_new( + vec![buffer![i64::MAX].into_array(), buffer![1i64].into_array()], + DType::Primitive(PType::I64, Nullability::NonNullable), + )?; + let result = run_sum(&chunked.into_array())?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn sum_chunked_nested() -> VortexResult<()> { + let inner = ChunkedArray::try_new( + vec![buffer![1i32, 2].into_array(), buffer![3i32].into_array()], + DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + let outer = ChunkedArray::try_new( + vec![inner.into_array(), buffer![4i32, 5, 6].into_array()], + DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + let result = run_sum(&outer.into_array())?; + assert_eq!(result.as_primitive().typed_value::(), Some(21)); + Ok(()) + } +} diff --git a/vortex-array/src/arrays/chunked/compute/mod.rs b/vortex-array/src/arrays/chunked/compute/mod.rs index 5b018c99ea9..791e261745c 100644 --- a/vortex-array/src/arrays/chunked/compute/mod.rs +++ b/vortex-array/src/arrays/chunked/compute/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +pub(crate) mod aggregate; mod cast; mod fill_null; mod filter; diff --git a/vortex-array/src/arrays/chunked/mod.rs b/vortex-array/src/arrays/chunked/mod.rs index 4ef5ed10cf2..aee36910daa 100644 --- a/vortex-array/src/arrays/chunked/mod.rs +++ b/vortex-array/src/arrays/chunked/mod.rs @@ -4,7 +4,7 @@ mod array; pub use array::ChunkedArray; -mod compute; +pub(crate) mod compute; mod vtable; pub use vtable::ChunkedVTable;