diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index 2d46a5cce8a..e99e4619143 100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -15,6 +15,7 @@ use vortex_array::VortexSessionExecute; use vortex_array::aggregate_fn::AggregateFnVTable; use vortex_array::aggregate_fn::DynGroupedAccumulator; use vortex_array::aggregate_fn::EmptyOptions; +use vortex_array::aggregate_fn::GroupIds; use vortex_array::aggregate_fn::GroupedAccumulator; use vortex_array::aggregate_fn::fns::count::Count; use vortex_array::aggregate_fn::fns::sum::Sum; @@ -45,24 +46,22 @@ fn total_element_count(group_sizes: &[usize]) -> usize { struct DenseGroupedInput { values: ArrayRef, - group_ids: Vec, - num_groups: usize, + group_ids: GroupIds, } fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput { assert_eq!(values.len(), total_element_count(group_sizes)); - let group_ids = group_sizes - .iter() - .enumerate() - .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)) - .collect(); + let group_ids = GroupIds::from_iter( + group_sizes + .iter() + .enumerate() + .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)), + group_sizes.len(), + ) + .unwrap(); - DenseGroupedInput { - values, - group_ids, - num_groups: group_sizes.len(), - } + DenseGroupedInput { values, group_ids } } fn i32_nullable_all_valid_input() -> DenseGroupedInput { @@ -142,14 +141,14 @@ where { let mut acc = GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap(); + let num_groups = input.group_ids.num_groups(); acc.accumulate( &input.values, &input.group_ids, - input.num_groups, &mut LEGACY_SESSION.create_execution_ctx(), ) .unwrap(); - divan::black_box(acc.finish(input.num_groups).unwrap()) + divan::black_box(acc.finish(num_groups).unwrap()) } #[divan::bench] diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 7a614ceed63..46064e3b000 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -7,7 +7,6 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_err; use crate::ArrayRef; -use crate::Columnar; use crate::ExecutionCtx; use crate::IntoArray; use crate::aggregate_fn::Accumulator; @@ -15,18 +14,92 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::aggregate_fn::session::AggregateFnSessionExt; +use crate::array::ArrayId; +use crate::arrays::PrimitiveArray; use crate::builders::builder_with_capacity; use crate::columnar::AnyColumnar; use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; use crate::executor::max_iterations; use crate::scalar::Scalar; +use crate::validity::Validity; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; -/// An accumulator used for computing aggregates over dense group ids. +/// Encoded group ids parallel to a grouped aggregate input batch. +/// +/// The array must contain non-null `u32` ordinals. The ordinals are dense state slots in +/// `0..num_groups`, not raw group keys. Range validation may require executing the encoded array, +/// so kernels that can prove the invariant from encoded metadata should avoid materializing and +/// otherwise call [`Self::validated_ids`] before indexing group state. +#[derive(Clone, Debug)] +pub struct GroupIds { + ids: ArrayRef, + num_groups: usize, +} + +impl GroupIds { + /// Create group ids from an encoded non-null `u32` array. + pub fn new(ids: ArrayRef, num_groups: usize) -> VortexResult { + validate_num_groups(num_groups)?; + vortex_ensure!( + ids.dtype() == &DType::Primitive(PType::U32, Nullability::NonNullable), + "Group ids must be non-nullable u32, got {}", + ids.dtype() + ); + Ok(Self { ids, num_groups }) + } + + /// Create group ids from a materialized buffer. + pub fn from_buffer(ids: Buffer, num_groups: usize) -> VortexResult { + Self::new( + PrimitiveArray::new(ids, Validity::NonNullable).into_array(), + num_groups, + ) + } + + /// Create group ids from materialized values. + pub fn from_iter(ids: impl IntoIterator, num_groups: usize) -> VortexResult { + Self::from_buffer(Buffer::from_iter(ids), num_groups) + } + + /// Return the encoded ids array. + pub fn ids(&self) -> &ArrayRef { + &self.ids + } + + /// Return the number of dense group state slots. + pub fn num_groups(&self) -> usize { + self.num_groups + } + + /// Return the number of ids. + pub fn len(&self) -> usize { + self.ids.len() + } + + /// Return whether there are no ids. + pub fn is_empty(&self) -> bool { + self.ids.is_empty() + } + + /// Return the encoding id for kernel dispatch. + pub fn encoding_id(&self) -> ArrayId { + self.ids.encoding_id() + } + + /// Execute the ids to a native buffer and validate every id is in range. + pub fn validated_ids(&self, ctx: &mut ExecutionCtx) -> VortexResult> { + let ids = self.ids.clone().execute::>(ctx)?; + validate_group_ids(ids.as_ref(), self.num_groups)?; + Ok(ids) + } +} + +/// An accumulator used for computing aggregates over group ids. /// /// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches /// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a @@ -88,54 +161,25 @@ impl GroupedAccumulator { Ok(()) } - fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> { - validate_num_groups(num_groups)?; - for &group_id in group_ids { - vortex_ensure!( - (group_id as usize) < num_groups, - "Group id {} out of range for {} groups", - group_id, - num_groups - ); - } - Ok(()) - } - - fn accumulate_kernel_result( - &mut self, - result: GroupedAggregateKernelResult, - num_groups: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx) - } - fn try_accumulate_kernel( &mut self, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult { let session = ctx.session().clone(); - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_encoding_kernel(batch.encoding_id(), self.aggregate_fn.id()) - && let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? - { - self.accumulate_kernel_result(result, num_groups, ctx)?; - return Ok(true); - } - - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_kernel(self.aggregate_fn.id()) - && let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? - { - self.accumulate_kernel_result(result, num_groups, ctx)?; + if let Some(kernel) = session.aggregate_fns().find_grouped_kernel( + self.aggregate_fn.id(), + batch.encoding_id(), + group_ids.encoding_id(), + ) && kernel.grouped_accumulate( + &self.aggregate_fn, + batch, + group_ids, + &mut self.partials, + ctx, + )? { return Ok(true); } @@ -198,18 +242,31 @@ fn validate_num_groups(num_groups: usize) -> VortexResult<()> { Ok(()) } +fn validate_group_ids(group_ids: &[u32], num_groups: usize) -> VortexResult<()> { + validate_num_groups(num_groups)?; + for &group_id in group_ids { + vortex_ensure!( + (group_id as usize) < num_groups, + "Group id {} out of range for {} groups", + group_id, + num_groups + ); + } + Ok(()) +} + /// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the /// aggregate function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { /// Accumulate a values batch into dense group state. /// /// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in - /// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch. + /// `0..group_ids.num_groups()`; ids may repeat, appear out of order, or be absent from a + /// given batch. fn accumulate( &mut self, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()>; @@ -220,8 +277,7 @@ pub trait DynGroupedAccumulator: 'static + Send { fn accumulate_partials( &mut self, partials: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()>; @@ -254,10 +310,10 @@ impl DynGroupedAccumulator for GroupedAccumulator { fn accumulate( &mut self, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let num_groups = group_ids.num_groups(); vortex_ensure!( batch.dtype() == &self.dtype, "Input DType mismatch: expected {}, got {}", @@ -271,56 +327,43 @@ impl DynGroupedAccumulator for GroupedAccumulator { group_ids.len() ); - self.validate_group_ids(group_ids, num_groups)?; self.ensure_groups(num_groups)?; - if self.try_accumulate_kernel(batch, group_ids, num_groups, ctx)? { - return Ok(()); - } - - if self.vtable.try_accumulate_grouped( - &mut self.partials[..num_groups], - batch, - group_ids, - ctx, - )? { + if self.try_accumulate_kernel(batch, group_ids, ctx)? { return Ok(()); } let input = batch.clone(); let mut batch = batch.clone(); + let mut tried_current = true; for _ in 0..max_iterations() { if batch.is::() { break; } - if self.try_accumulate_kernel(&batch, group_ids, num_groups, ctx)? { + if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? { return Ok(()); } batch = batch.execute(ctx)?; + tried_current = false; } - let columnar = batch.clone().execute::(ctx)?; - if self.vtable.accumulate_grouped( - &mut self.partials[..num_groups], - &columnar, - group_ids, - ctx, - )? { + if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? { return Ok(()); } - self.accumulate_fallback(&input, group_ids, ctx) + let group_ids = group_ids.validated_ids(ctx)?; + self.accumulate_fallback(&input, group_ids.as_ref(), ctx) } fn accumulate_partials( &mut self, partials: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let num_groups = group_ids.num_groups(); vortex_ensure!( partials.dtype() == &self.partial_dtype, "Partial DType mismatch: expected {}, got {}", @@ -334,7 +377,7 @@ impl DynGroupedAccumulator for GroupedAccumulator { group_ids.len() ); - self.validate_group_ids(group_ids, num_groups)?; + let group_ids = group_ids.validated_ids(ctx)?; self.ensure_groups(num_groups)?; for (row_idx, &group_id) in group_ids.iter().enumerate() { diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index 03e2b1b49ae..68ea4e05d26 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -3,20 +3,36 @@ use vortex_error::VortexResult; +use super::Count; use crate::ArrayRef; use crate::ExecutionCtx; +use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::GroupIds; +use crate::aggregate_fn::kernels::GroupedAggregateKernel; +use crate::aggregate_fn::kernels::GroupedAggregateKernelAdapter; -pub(super) fn try_accumulate_grouped( - states: &mut [u64], - batch: &ArrayRef, - group_ids: &[u32], - ctx: &mut ExecutionCtx, -) -> VortexResult { - let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; - for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { - if valid { - states[group_id as usize] += 1; +pub(crate) static COUNT_GROUPED_KERNEL: GroupedAggregateKernelAdapter = + GroupedAggregateKernelAdapter::new(CountGroupedKernel); + +#[derive(Debug)] +pub(crate) struct CountGroupedKernel; + +impl GroupedAggregateKernel for CountGroupedKernel { + fn grouped_accumulate( + &self, + _options: &EmptyOptions, + states: &mut [u64], + batch: &ArrayRef, + group_ids: &GroupIds, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let group_ids = group_ids.validated_ids(ctx)?; + let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; + for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { + if valid { + states[group_id as usize] += 1; + } } + Ok(true) } - Ok(true) } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index e53a378b5a9..c6a9c27d52f 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod grouped; +pub(crate) use grouped::COUNT_GROUPED_KERNEL; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -95,16 +96,6 @@ impl AggregateFnVTable for Count { Ok(true) } - fn try_accumulate_grouped( - &self, - states: &mut [Self::Partial], - batch: &ArrayRef, - group_ids: &[u32], - ctx: &mut ExecutionCtx, - ) -> VortexResult { - grouped::try_accumulate_grouped(states, batch, group_ids, ctx) - } - fn accumulate( &self, _partial: &mut Self::Partial, @@ -139,6 +130,7 @@ mod tests { use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupIds; use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::count::Count; use crate::arrays::ChunkedArray; @@ -258,10 +250,10 @@ mod tests { num_groups: usize, ) -> VortexResult { let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + let group_ids = GroupIds::from_iter(group_ids.iter().copied(), num_groups)?; acc.accumulate( values, - group_ids, - num_groups, + &group_ids, &mut LEGACY_SESSION.create_execution_ctx(), )?; acc.finish(num_groups) @@ -307,13 +299,30 @@ mod tests { Ok(()) } + #[test] + fn grouped_count_constant_group_ids() -> VortexResult<()> { + let values = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(); + let group_ids = GroupIds::new(ConstantArray::new(1u32, values.len()).into_array(), 3)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + + acc.accumulate(&values, &group_ids, &mut ctx)?; + let actual = acc.finish(3)?; + + let expected = PrimitiveArray::from_iter([0u64, 3, 0]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn grouped_count_rejects_out_of_range_group_id() -> VortexResult<()> { let values = PrimitiveArray::new(buffer![1i32, 2], Validity::NonNullable).into_array(); let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let group_ids = GroupIds::from_iter([0u32, 2], 2)?; - assert!(acc.accumulate(&values, &[0, 2], 2, &mut ctx).is_err()); + assert!(acc.accumulate(&values, &group_ids, &mut ctx).is_err()); Ok(()) } @@ -324,7 +333,8 @@ mod tests { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let mut left = GroupedAccumulator::try_new(Count, EmptyOptions, dtype.clone())?; - left.accumulate_partials(&partials, &[0, 1, 1], 2, &mut ctx)?; + let group_ids = GroupIds::from_iter([0u32, 1, 1], 2)?; + left.accumulate_partials(&partials, &group_ids, &mut ctx)?; let mut right = GroupedAccumulator::try_new(Count, EmptyOptions, dtype)?; right.merge_group(0, &left, 1)?; diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 81304f1eb9f..e7a73059fc3 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -9,6 +9,7 @@ use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; +use super::Sum; use super::SumPartial; use super::SumState; use super::checked_add_i64; @@ -16,8 +17,15 @@ use super::checked_add_u64; use super::primitive::sum_float_all; use super::primitive::sum_signed_all; use super::primitive::sum_unsigned_all; +use crate::ArrayRef; use crate::ExecutionCtx; +use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::GroupIds; +use crate::aggregate_fn::kernels::GroupedAggregateKernel; +use crate::aggregate_fn::kernels::GroupedAggregateKernelAdapter; +use crate::arrays::Bool; use crate::arrays::BoolArray; +use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; use crate::arrays::bool::BoolArrayExt; use crate::dtype::NativePType; @@ -25,6 +33,39 @@ use crate::match_each_native_ptype; const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS: usize = 4; +pub(crate) static SUM_GROUPED_KERNEL: GroupedAggregateKernelAdapter = + GroupedAggregateKernelAdapter::new(SumGroupedKernel); + +#[derive(Debug)] +pub(crate) struct SumGroupedKernel; + +impl GroupedAggregateKernel for SumGroupedKernel { + fn grouped_accumulate( + &self, + _options: &EmptyOptions, + partials: &mut [SumPartial], + batch: &ArrayRef, + group_ids: &GroupIds, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + if let Some(primitive) = batch.as_opt::() { + let group_ids = group_ids.validated_ids(ctx)?; + let primitive = primitive.into_owned(); + accumulate_grouped_primitive(partials, &primitive, group_ids.as_ref(), ctx)?; + return Ok(true); + } + + if let Some(bools) = batch.as_opt::() { + let group_ids = group_ids.validated_ids(ctx)?; + let bools = bools.into_owned(); + accumulate_grouped_bool(partials, &bools, group_ids.as_ref(), ctx)?; + return Ok(true); + } + + Ok(false) + } +} + fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { match validity.indices() { AllOr::All => { diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index eff487d55e8..207b8140922 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -7,6 +7,7 @@ mod decimal; mod grouped; mod primitive; +pub(crate) use grouped::SUM_GROUPED_KERNEL; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -281,30 +282,6 @@ impl AggregateFnVTable for Sum { Ok(()) } - fn accumulate_grouped( - &self, - partials: &mut [Self::Partial], - batch: &Columnar, - group_ids: &[u32], - ctx: &mut ExecutionCtx, - ) -> VortexResult { - match batch { - Columnar::Canonical(Canonical::Primitive(p)) => { - grouped::accumulate_grouped_primitive(partials, p, group_ids, ctx)?; - Ok(true) - } - Columnar::Canonical(Canonical::Bool(b)) => { - grouped::accumulate_grouped_bool(partials, b, group_ids, ctx)?; - Ok(true) - } - // Decimal and constants still use the universal grouped fallback. - Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => Ok(false), - Columnar::Canonical(_) => { - vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()) - } - } - } - fn finalize(&self, partials: ArrayRef) -> VortexResult { Ok(partials) } @@ -439,6 +416,7 @@ mod tests { use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupIds; use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; @@ -616,10 +594,10 @@ mod tests { num_groups: usize, ) -> VortexResult { let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, values.dtype().clone())?; + let group_ids = GroupIds::from_iter(group_ids.iter().copied(), num_groups)?; acc.accumulate( values, - group_ids, - num_groups, + &group_ids, &mut LEGACY_SESSION.create_execution_ctx(), )?; acc.finish(num_groups) @@ -689,14 +667,16 @@ mod tests { let values1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - acc.accumulate(&values1, &[0, 0, 1, 1], 2, &mut ctx)?; + let group_ids1 = GroupIds::from_iter([0u32, 0, 1, 1], 2)?; + acc.accumulate(&values1, &group_ids1, &mut ctx)?; let result1 = acc.finish(2)?; let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result1, &expected1); let values2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - acc.accumulate(&values2, &[0, 0], 1, &mut ctx)?; + let group_ids2 = GroupIds::from_iter([0u32, 0], 1)?; + acc.accumulate(&values2, &group_ids2, &mut ctx)?; let result2 = acc.finish(1)?; let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index e0b1d42e41e..51d47c33a2e 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -4,14 +4,19 @@ //! Pluggable aggregate function kernels used to provide encoding-specific implementations of //! aggregate functions. +use std::any::Any; use std::fmt::Debug; +use std::marker::PhantomData; -use vortex_buffer::Buffer; use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::GroupIds; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -27,53 +32,110 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { ) -> VortexResult>; } -/// Partial grouped aggregate output produced by an encoding-specific grouped kernel. +/// A typed grouped aggregate kernel. /// -/// `group_ids` is parallel to `partials`: each row in `partials` is a partial state for the -/// corresponding dense group ordinal. The ids may repeat, omit, and reorder groups, but must be -/// valid slots in the accumulator's `0..num_groups` range. The grouped accumulator merges this -/// batch through `accumulate_partials`. -#[derive(Clone, Debug)] -pub struct GroupedAggregateKernelResult { - group_ids: Buffer, - partials: ArrayRef, +/// Implementations receive the concrete aggregate options and typed partial state. Return +/// `Ok(false)` when the kernel cannot handle the current values or group-id encodings. +pub trait GroupedAggregateKernel: 'static + Send + Sync + Debug { + /// Accumulate `batch` into `states` according to `group_ids`. + fn grouped_accumulate( + &self, + options: &V::Options, + states: &mut [V::Partial], + batch: &ArrayRef, + group_ids: &GroupIds, + ctx: &mut ExecutionCtx, + ) -> VortexResult; +} + +/// Bridges a typed [`GroupedAggregateKernel`] to type-erased grouped kernel dispatch. +pub struct GroupedAggregateKernelAdapter { + kernel: K, + _phantom: PhantomData V>, } -impl GroupedAggregateKernelResult { - pub fn new(group_ids: Buffer, partials: ArrayRef) -> Self { +impl GroupedAggregateKernelAdapter { + /// Create a new adapter around `kernel`. + pub const fn new(kernel: K) -> Self { Self { - group_ids, - partials, + kernel, + _phantom: PhantomData, } } +} - pub fn group_ids(&self) -> &[u32] { - self.group_ids.as_ref() - } - - pub fn partials(&self) -> &ArrayRef { - &self.partials +impl Debug for GroupedAggregateKernelAdapter +where + V: AggregateFnVTable, + K: GroupedAggregateKernel, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GroupedAggregateKernelAdapter") + .field("kernel", &self.kernel) + .finish() } } /// A pluggable kernel for batch aggregation of many groups. /// -/// A grouped kernel can be registered for an aggregate function regardless of input encoding, or -/// for a specific aggregate function and array encoding. Encoding-specific kernels are matched on -/// the values array, not on a pre-grouped list wrapper. +/// A grouped kernel can be registered for an aggregate function regardless of input encodings, or +/// for a specific aggregate function plus values and/or group-id encoding. /// /// Kernels receive the same dense group ordinals that the caller passed to the grouped accumulator /// and may aggregate directly in the encoded domain. /// -/// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. +/// Return `Ok(false)` if the kernel cannot be applied to the given aggregate function or input +/// encodings. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate values into a partial-state batch keyed by dense group ordinal. - fn grouped_aggregate( + /// Accumulate values into type-erased partial state. + fn grouped_accumulate( + &self, + aggregate_fn: &AggregateFnRef, + batch: &ArrayRef, + group_ids: &GroupIds, + states: &mut dyn Any, + ctx: &mut ExecutionCtx, + ) -> VortexResult; +} + +impl DynGroupedAggregateKernel for GroupedAggregateKernelAdapter +where + V: AggregateFnVTable, + K: GroupedAggregateKernel, +{ + fn grouped_accumulate( &self, aggregate_fn: &AggregateFnRef, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, + states: &mut dyn Any, ctx: &mut ExecutionCtx, - ) -> VortexResult>; + ) -> VortexResult { + let Some(options) = aggregate_fn.as_opt::() else { + return Ok(false); + }; + + let Some(states) = states.downcast_mut::>() else { + vortex_bail!( + "Grouped aggregate kernel for {} received incompatible partial state", + aggregate_fn.id() + ); + }; + + vortex_ensure!( + states.len() >= group_ids.num_groups(), + "Grouped aggregate kernel for {} received {} partial states for {} groups", + aggregate_fn.id(), + states.len(), + group_ids.num_groups() + ); + + self.kernel.grouped_accumulate( + options, + &mut states[..group_ids.num_groups()], + batch, + group_ids, + ctx, + ) + } } diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 78b139bf36f..14a5ccb261d 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,6 +18,8 @@ use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_min::BoundedMin; +use crate::aggregate_fn::fns::count::COUNT_GROUPED_KERNEL; +use crate::aggregate_fn::fns::count::Count; use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; @@ -27,6 +29,7 @@ use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; +use crate::aggregate_fn::fns::sum::SUM_GROUPED_KERNEL; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; @@ -51,9 +54,7 @@ pub struct AggregateFnSession { registry: ArcSwapMap, kernels: ArcSwapMap, - grouped_kernels: ArcSwapMap, - grouped_encoding_kernels: - ArcSwapMap, + grouped_kernels: ArcSwapMap, } impl SessionVar for AggregateFnSession { @@ -67,7 +68,27 @@ impl SessionVar for AggregateFnSession { } type AggregateKernelKey = (ArrayId, Option); -type GroupedEncodingKernelKey = (ArrayId, AggregateFnId); + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +struct GroupedAggregateKernelKey { + aggregate_id: AggregateFnId, + values_id: Option, + group_ids_id: Option, +} + +impl GroupedAggregateKernelKey { + fn new( + aggregate_id: AggregateFnId, + values_id: Option, + group_ids_id: Option, + ) -> Self { + Self { + aggregate_id, + values_id, + group_ids_id, + } + } +} impl Default for AggregateFnSession { fn default() -> Self { @@ -75,7 +96,6 @@ impl Default for AggregateFnSession { registry: ArcSwapMap::default(), kernels: ArcSwapMap::default(), grouped_kernels: ArcSwapMap::default(), - grouped_encoding_kernels: ArcSwapMap::default(), }; // Register the built-in aggregate functions @@ -103,6 +123,8 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict.id(), Some(MinMax.id()), &DictMinMaxKernel); this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel); + this.register_grouped_kernel(Count.id(), None, None, &COUNT_GROUPED_KERNEL); + this.register_grouped_kernel(Sum.id(), None, None, &SUM_GROUPED_KERNEL); this } @@ -156,54 +178,62 @@ impl AggregateFnSession { self.kernels.insert(id, kernel); } - /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any. + /// Returns the grouped aggregate kernel registered for this aggregate and pair of encodings. /// - /// These kernels are independent of the element encoding and are checked for each element - /// representation, after any kernel registered for the current element encoding. + /// Lookup first checks the exact `(aggregate, values encoding, group ids encoding)` key, then + /// falls back through `(aggregate, values encoding, any group ids)`, `(aggregate, any values, + /// group ids encoding)`, and finally `(aggregate, any values, any group ids)`. pub fn find_grouped_kernel( &self, agg_fn_id: impl Into, + values_id: impl Into, + group_ids_id: impl Into, ) -> Option<&'static dyn DynGroupedAggregateKernel> { let fn_id = agg_fn_id.into(); - self.grouped_kernels - .read(|kernels| kernels.get(&fn_id).copied()) - } - - /// Registers a grouped aggregate kernel for an aggregate function. - pub fn register_grouped_kernel( - &self, - agg_fn_id: impl Into, - kernel: &'static dyn DynGroupedAggregateKernel, - ) { - let fn_id = agg_fn_id.into(); - self.grouped_kernels.insert(fn_id, kernel) + let values_id = values_id.into(); + let group_ids_id = group_ids_id.into(); + self.grouped_kernels.read(|kernels| { + kernels + .get(&GroupedAggregateKernelKey::new( + fn_id, + Some(values_id), + Some(group_ids_id), + )) + .or_else(|| { + kernels.get(&GroupedAggregateKernelKey::new( + fn_id, + Some(values_id), + None, + )) + }) + .or_else(|| { + kernels.get(&GroupedAggregateKernelKey::new( + fn_id, + None, + Some(group_ids_id), + )) + }) + .or_else(|| kernels.get(&GroupedAggregateKernelKey::new(fn_id, None, None))) + .copied() + }) } - /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any. + /// Registers a grouped aggregate kernel. /// - /// These kernels are matched against each intermediate element encoding while the grouped - /// accumulator executes the element array. - pub fn find_grouped_encoding_kernel( - &self, - array_id: impl Into, - agg_fn_id: impl Into, - ) -> Option<&'static dyn DynGroupedAggregateKernel> { - let id = array_id.into(); - let fn_id = agg_fn_id.into(); - self.grouped_encoding_kernels - .read(|kernels| kernels.get(&(id, fn_id)).copied()) - } - - /// Registers a grouped aggregate kernel for a specific aggregate function and array encoding. - pub fn register_grouped_encoding_kernel( + /// `values_id` and `group_ids_id` are optional wildcards. Passing `None` for either dimension + /// makes the kernel a fallback for that encoding dimension. + pub fn register_grouped_kernel( &self, - array_id: impl Into, agg_fn_id: impl Into, + values_id: Option, + group_ids_id: Option, kernel: &'static dyn DynGroupedAggregateKernel, ) { - let id = array_id.into(); let fn_id = agg_fn_id.into(); - self.grouped_encoding_kernels.insert((id, fn_id), kernel) + self.grouped_kernels.insert( + GroupedAggregateKernelKey::new(fn_id, values_id, group_ids_id), + kernel, + ) } } @@ -215,3 +245,86 @@ pub trait AggregateFnSessionExt: SessionExt { } } impl AggregateFnSessionExt for S {} + +#[cfg(test)] +mod tests { + use std::any::Any; + + use vortex_error::VortexResult; + + use super::*; + use crate::ArrayRef; + use crate::ExecutionCtx; + use crate::aggregate_fn::AggregateFnRef; + use crate::aggregate_fn::GroupIds; + use crate::arrays::Constant; + use crate::arrays::Primitive; + + #[derive(Debug)] + struct TestGroupedKernel; + + impl DynGroupedAggregateKernel for TestGroupedKernel { + fn grouped_accumulate( + &self, + _aggregate_fn: &AggregateFnRef, + _batch: &ArrayRef, + _group_ids: &GroupIds, + _states: &mut dyn Any, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + } + + static GENERIC_KERNEL: TestGroupedKernel = TestGroupedKernel; + static GROUP_IDS_KERNEL: TestGroupedKernel = TestGroupedKernel; + static VALUES_KERNEL: TestGroupedKernel = TestGroupedKernel; + static EXACT_KERNEL: TestGroupedKernel = TestGroupedKernel; + + fn assert_same_kernel( + actual: Option<&'static dyn DynGroupedAggregateKernel>, + expected: &'static dyn DynGroupedAggregateKernel, + ) { + assert!(std::ptr::eq( + actual.expect("expected registered grouped kernel"), + expected + )); + } + + #[test] + fn grouped_kernel_lookup_prefers_exact_then_value_then_group_ids() { + let session = AggregateFnSession::default(); + let aggregate_id = AggregateFnId::new("test.grouped_lookup"); + let values_id = Primitive.id(); + let group_ids_id = Constant.id(); + + session.register_grouped_kernel(aggregate_id, None, None, &GENERIC_KERNEL); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &GENERIC_KERNEL, + ); + + session.register_grouped_kernel(aggregate_id, None, Some(group_ids_id), &GROUP_IDS_KERNEL); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &GROUP_IDS_KERNEL, + ); + + session.register_grouped_kernel(aggregate_id, Some(values_id), None, &VALUES_KERNEL); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &VALUES_KERNEL, + ); + + session.register_grouped_kernel( + aggregate_id, + Some(values_id), + Some(group_ids_id), + &EXACT_KERNEL, + ); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &EXACT_KERNEL, + ); + } +} diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 09eab6c5a9c..ab9edae5862 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -157,37 +157,6 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { ctx: &mut ExecutionCtx, ) -> VortexResult<()>; - /// Try to accumulate a raw values batch into dense per-group states before decompression. - /// - /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in - /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. - /// Returns `true` when the batch was fully handled. - fn try_accumulate_grouped( - &self, - _states: &mut [Self::Partial], - _batch: &ArrayRef, - _group_ids: &[u32], - _ctx: &mut ExecutionCtx, - ) -> VortexResult { - Ok(false) - } - - /// Accumulate a canonical values batch into dense per-group states. - /// - /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in - /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. - /// Returns `true` when the batch was fully handled. The provided default preserves universal - /// correctness through [`crate::aggregate_fn::GroupedAccumulator`]'s fallback. - fn accumulate_grouped( - &self, - _states: &mut [Self::Partial], - _batch: &Columnar, - _group_ids: &[u32], - _ctx: &mut ExecutionCtx, - ) -> VortexResult { - Ok(false) - } - /// Finalize an array of accumulator states into an array of aggregate results. /// /// The provides `states` array has dtype as specified by `state_dtype`, the result array