diff --git a/encodings/alp/src/alp_rd/array.rs b/encodings/alp/src/alp_rd/array.rs index d98b596eaa8..583b20a5805 100644 --- a/encodings/alp/src/alp_rd/array.rs +++ b/encodings/alp/src/alp_rd/array.rs @@ -17,12 +17,14 @@ use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::PrimitiveVTable; use vortex_array::buffer::BufferHandle; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::patches::Patches; use vortex_array::patches::PatchesMetadata; +use vortex_array::require_child; use vortex_array::serde::ArrayChildren; use vortex_array::stats::ArrayStats; use vortex_array::stats::StatsSetRef; @@ -41,7 +43,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use vortex_mask::Mask; use vortex_session::VortexSession; use crate::alp_rd::kernel::PARENT_KERNELS; @@ -297,17 +298,12 @@ impl VTable for ALPRDVTable { } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - let left_parts = array.left_parts().clone().execute::(ctx)?; - let right_parts = array.right_parts().clone().execute::(ctx)?; + let left_parts = require_child!(array.left_parts(), 0 => PrimitiveVTable).clone(); + let right_parts = require_child!(array.right_parts(), 1 => PrimitiveVTable).clone(); // Decode the left_parts using our builtin dictionary. let left_parts_dict = array.left_parts_dictionary(); - - let validity = array - .left_parts() - .validity()? - .to_array(array.len()) - .execute::(ctx)?; + let validity = left_parts.validity_mask()?; let decoded_array = if array.is_f32() { PrimitiveArray::new( diff --git a/encodings/decimal-byte-parts/public-api.lock b/encodings/decimal-byte-parts/public-api.lock index 6ee1ebc7e5f..39b0f9e87b8 100644 --- a/encodings/decimal-byte-parts/public-api.lock +++ b/encodings/decimal-byte-parts/public-api.lock @@ -112,7 +112,7 @@ pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::deserialize(bytes: &[u pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::dtype(array: &vortex_decimal_byte_parts::DecimalBytePartsArray) -> &vortex_array::dtype::DType -pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::execute(array: &Self::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::execute(array: &Self::Array, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::execute_parent(array: &Self::Array, parent: &vortex_array::array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 79e24e1b0cd..155d326b8dc 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -8,6 +8,7 @@ mod slice; use std::hash::Hash; use prost::Message as _; +use vortex_array::AnyCanonical; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; @@ -18,8 +19,9 @@ use vortex_array::IntoArray; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::SerializeMetadata; +use vortex_array::arrays::ConstantVTable; use vortex_array::arrays::DecimalArray; -use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::PrimitiveVTable; use vortex_array::buffer::BufferHandle; use vortex_array::dtype::DType; use vortex_array::dtype::DecimalDType; @@ -190,8 +192,31 @@ impl VTable for DecimalBytePartsVTable { PARENT_RULES.evaluate(array, parent, child_idx) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - to_canonical_decimal(array, ctx).map(ExecutionStep::Done) + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + // Ensure msp (child 0) is a PrimitiveArray. + let prim = if let Some(primitive) = array.msp.as_opt::() { + primitive.clone() + } else if array.msp.is::() { + array.msp.to_canonical()?.into_primitive() + } else { + return Ok(ExecutionStep::execute_child::(0)); + }; + + Ok(ExecutionStep::Done(match_each_signed_integer_ptype!( + prim.ptype(), + |P| { + // SAFETY: The primitive array's buffer is already validated with correct type. + // The decimal dtype matches the array's dtype, and validity is preserved. + unsafe { + DecimalArray::new_unchecked( + prim.to_buffer::

(), + *array.decimal_dtype(), + prim.validity().clone(), + ) + } + .into_array() + } + ))) } fn execute_parent( @@ -277,30 +302,6 @@ impl DecimalBytePartsVTable { pub const ID: ArrayId = ArrayId::new_ref("vortex.decimal_byte_parts"); } -/// Converts a DecimalBytePartsArray to its canonical DecimalArray representation. -fn to_canonical_decimal( - array: &DecimalBytePartsArray, - ctx: &mut ExecutionCtx, -) -> VortexResult { - // TODO(joe): support parts len != 1 - let prim = array.msp.clone().execute::(ctx)?; - // Depending on the decimal type and the min/max of the primitive array we can choose - // the correct buffer size - - Ok(match_each_signed_integer_ptype!(prim.ptype(), |P| { - // SAFETY: The primitive array's buffer is already validated with correct type. - // The decimal dtype matches the array's dtype, and validity is preserved. - unsafe { - DecimalArray::new_unchecked( - prim.to_buffer::

(), - *array.decimal_dtype(), - prim.validity().clone(), - ) - } - .into_array() - })) -} - impl OperationsVTable for DecimalBytePartsVTable { fn scalar_at(array: &DecimalBytePartsArray, index: usize) -> VortexResult { // TODO(joe): support parts len != 1 diff --git a/encodings/fastlanes/src/for/array/for_compress.rs b/encodings/fastlanes/src/for/array/for_compress.rs index 95277505360..b70626e0642 100644 --- a/encodings/fastlanes/src/for/array/for_compress.rs +++ b/encodings/fastlanes/src/for/array/for_compress.rs @@ -68,8 +68,9 @@ mod test { use super::*; use crate::BitPackedArray; - use crate::r#for::array::for_decompress::decompress; + use crate::r#for::array::for_decompress::apply_reference; use crate::r#for::array::for_decompress::fused_decompress; + use crate::r#for::array::for_decompress::try_fused_decompress; static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); @@ -169,7 +170,13 @@ mod test { let expected_unsigned = PrimitiveArray::from_iter(unsigned); assert_arrays_eq!(encoded, expected_unsigned); - let decompressed = decompress(&compressed, &mut SESSION.create_execution_ctx())?; + let mut ctx = SESSION.create_execution_ctx(); + let decompressed = if let Some(result) = try_fused_decompress(&compressed, &mut ctx)? { + result + } else { + let encoded = compressed.encoded().to_primitive(); + apply_reference(&compressed, encoded) + }; array .as_slice::() .iter() diff --git a/encodings/fastlanes/src/for/array/for_decompress.rs b/encodings/fastlanes/src/for/array/for_decompress.rs index f7292a5a06e..b462fa0646d 100644 --- a/encodings/fastlanes/src/for/array/for_decompress.rs +++ b/encodings/fastlanes/src/for/array/for_decompress.rs @@ -45,23 +45,28 @@ impl + FoR> UnpackStrategy for FoRStrategy } } -pub fn decompress(array: &FoRArray, ctx: &mut ExecutionCtx) -> VortexResult { - let ptype = array.ptype(); - - // Try to do fused unpack. +/// Try the fused BitPacked decompression path. Returns `None` if the child is not BitPacked +/// or the reference type is not unsigned. +pub fn try_fused_decompress( + array: &FoRArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { if array.reference_scalar().dtype().is_unsigned_int() && let Some(bp) = array.encoded().as_opt::() { return match_each_unsigned_integer_ptype!(array.ptype(), |T| { - fused_decompress::(array, bp, ctx) + fused_decompress::(array, bp, ctx).map(Some) }); } + Ok(None) +} - // TODO(ngates): Do we need this to be into_encoded() somehow? - let encoded = array.encoded().clone().execute::(ctx)?; +/// Apply the FoR reference value to an already-decoded PrimitiveArray. +pub fn apply_reference(array: &FoRArray, encoded: PrimitiveArray) -> PrimitiveArray { + let ptype = array.ptype(); let validity = encoded.validity().clone(); - Ok(match_each_integer_ptype!(ptype, |T| { + match_each_integer_ptype!(ptype, |T| { let min = array .reference_scalar() .as_primitive() @@ -75,7 +80,7 @@ pub fn decompress(array: &FoRArray, ctx: &mut ExecutionCtx) -> VortexResult VortexResult { - Ok(ExecutionStep::Done(decompress(array, ctx)?.into_array())) + // Try fused decompress with BitPacked child (no child execution needed). + if let Some(result) = try_fused_decompress(array, ctx)? { + return Ok(ExecutionStep::Done(result.into_array())); + } + + // If child is already a PrimitiveArray, add the reference value. + if array.encoded().is::() { + let encoded = array.encoded().as_::().clone(); + return Ok(ExecutionStep::Done( + apply_reference(array, encoded).into_array(), + )); + } + + // If child is a constant, compute the result as a constant. + if let Some(constant) = array.encoded().as_opt::() { + let scalar = constant.scalar(); + if scalar.is_null() { + return Ok(ExecutionStep::Done( + ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()) + .into_array(), + )); + } + return Ok(ExecutionStep::Done(match_each_integer_ptype!( + array.ptype(), + |T| { + let enc_val = scalar + .as_primitive() + .typed_value::() + .vortex_expect("constant must be non-null after check"); + let ref_val = array + .reference_scalar() + .as_primitive() + .typed_value::() + .vortex_expect("reference must be non-null"); + ConstantArray::new( + Scalar::primitive( + enc_val.wrapping_add(ref_val), + scalar.dtype().nullability(), + ), + array.len(), + ) + .into_array() + } + ))); + } + + // Otherwise, ask the scheduler to execute the child first. + Ok(ExecutionStep::execute_child::(0)) } fn execute_parent( diff --git a/encodings/zigzag/public-api.lock b/encodings/zigzag/public-api.lock index d3f94f2b73a..8b1a0c66775 100644 --- a/encodings/zigzag/public-api.lock +++ b/encodings/zigzag/public-api.lock @@ -100,7 +100,7 @@ pub fn vortex_zigzag::ZigZagVTable::deserialize(_bytes: &[u8], _dtype: &vortex_a pub fn vortex_zigzag::ZigZagVTable::dtype(array: &vortex_zigzag::ZigZagArray) -> &vortex_array::dtype::DType -pub fn vortex_zigzag::ZigZagVTable::execute(array: &Self::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_zigzag::ZigZagVTable::execute(array: &Self::Array, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_zigzag::ZigZagVTable::execute_parent(array: &Self::Array, parent: &vortex_array::array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index f8139246880..02195290777 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -3,6 +3,7 @@ use std::hash::Hash; +use vortex_array::AnyCanonical; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; @@ -12,6 +13,9 @@ use vortex_array::ExecutionCtx; use vortex_array::ExecutionStep; use vortex_array::IntoArray; use vortex_array::Precision; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::ConstantVTable; +use vortex_array::arrays::PrimitiveVTable; use vortex_array::buffer::BufferHandle; use vortex_array::dtype::DType; use vortex_array::dtype::PType; @@ -149,10 +153,39 @@ impl VTable for ZigZagVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionStep::Done( - zigzag_decode(array.encoded().clone().execute(ctx)?).into_array(), - )) + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + // If child is already a PrimitiveArray, decode it directly. + if array.encoded().is::() { + let encoded = array.encoded().as_::().clone(); + return Ok(ExecutionStep::Done(zigzag_decode(encoded).into_array())); + } + + // If child is a constant, decode the scalar value. + if let Some(constant) = array.encoded().as_opt::() { + let scalar = constant.scalar(); + if scalar.is_null() { + return Ok(ExecutionStep::Done( + ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()) + .into_array(), + )); + } + let result = match_each_unsigned_integer_ptype!(scalar.as_primitive().ptype(), |P| { + let val = scalar + .as_primitive() + .typed_value::

() + .vortex_expect("constant must be non-null after check"); + Scalar::primitive( + <

::Int>::decode(val), + array.dtype().nullability(), + ) + }); + return Ok(ExecutionStep::Done( + ConstantArray::new(result, array.len()).into_array(), + )); + } + + // Otherwise, ask the scheduler to execute the child first. + Ok(ExecutionStep::execute_child::(0)) } fn reduce_parent( diff --git a/encodings/zstd/public-api.lock b/encodings/zstd/public-api.lock index 142a350991a..7490003ab69 100644 --- a/encodings/zstd/public-api.lock +++ b/encodings/zstd/public-api.lock @@ -200,7 +200,7 @@ pub fn vortex_zstd::ZstdVTable::deserialize(bytes: &[u8], _dtype: &vortex_array: pub fn vortex_zstd::ZstdVTable::dtype(array: &vortex_zstd::ZstdArray) -> &vortex_array::dtype::DType -pub fn vortex_zstd::ZstdVTable::execute(array: &Self::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_zstd::ZstdVTable::execute(array: &Self::Array, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_zstd::ZstdVTable::id(_array: &Self::Array) -> vortex_array::vtable::dyn_::ArrayId diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index 52665408abb..12ffb0f7fad 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -272,11 +272,8 @@ impl VTable for ZstdVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - array - .decompress()? - .execute::(ctx) - .map(ExecutionStep::Done) + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + Ok(ExecutionStep::Done(array.decompress()?)) } fn reduce_parent( diff --git a/encodings/zstd/src/zstd_buffers.rs b/encodings/zstd/src/zstd_buffers.rs index ff733475a95..3fd2e84ce3e 100644 --- a/encodings/zstd/src/zstd_buffers.rs +++ b/encodings/zstd/src/zstd_buffers.rs @@ -470,9 +470,7 @@ impl VTable for ZstdBuffersVTable { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { let session = ctx.session(); let inner_array = array.decompress_and_build_inner(session)?; - inner_array - .execute::(ctx) - .map(ExecutionStep::Done) + Ok(ExecutionStep::Done(inner_array)) } } diff --git a/vortex-array/src/arrays/constant/mod.rs b/vortex-array/src/arrays/constant/mod.rs index 6aa94ab8565..e3e8d2ddea7 100644 --- a/vortex-array/src/arrays/constant/mod.rs +++ b/vortex-array/src/arrays/constant/mod.rs @@ -14,3 +14,4 @@ pub(crate) mod compute; mod vtable; pub use vtable::ConstantVTable; +pub(crate) use vtable::canonical::constant_canonicalize; diff --git a/vortex-array/src/arrays/dict/vtable/mod.rs b/vortex-array/src/arrays/dict/vtable/mod.rs index ffde0e4223b..215335e10fd 100644 --- a/vortex-array/src/arrays/dict/vtable/mod.rs +++ b/vortex-array/src/arrays/dict/vtable/mod.rs @@ -14,6 +14,7 @@ use vortex_session::VortexSession; use super::DictArray; use super::DictMetadata; use super::take_canonical; +use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; use crate::DeserializeMetadata; @@ -23,6 +24,7 @@ use crate::Precision; use crate::ProstMetadata; use crate::SerializeMetadata; use crate::arrays::ConstantArray; +use crate::arrays::PrimitiveVTable; use crate::arrays::dict::compute::rules::PARENT_RULES; use crate::buffer::BufferHandle; use crate::dtype::DType; @@ -32,6 +34,7 @@ use crate::executor::ExecutionCtx; use crate::executor::ExecutionStep; use crate::hash::ArrayEq; use crate::hash::ArrayHash; +use crate::require_child; use crate::scalar::Scalar; use crate::serde::ArrayChildren; use crate::stats::StatsSetRef; @@ -192,25 +195,29 @@ impl VTable for DictVTable { } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - if let Some(canonical) = execute_fast_path(array, ctx)? { - return Ok(ExecutionStep::Done(canonical)); + if array.is_empty() { + let result_dtype = array + .dtype() + .union_nullability(array.codes().dtype().nullability()); + return Ok(ExecutionStep::Done( + Canonical::empty(&result_dtype).into_array(), + )); } - // TODO(joe): if the values are constant return a constant - let values = array.values().clone().execute::(ctx)?; - let codes = array - .codes() - .clone() - .execute::(ctx)? - .into_primitive(); + let codes = require_child!(array.codes(), 0 => PrimitiveVTable); - // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to - // the filter function since they're typically optimised for this case. - // TODO(ngates): if indices min is quite high, we could slice self and offset the indices - // such that canonicalize does less work. + if codes.all_invalid()? { + return Ok(ExecutionStep::Done( + ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.codes.len()) + .into_array(), + )); + } + + let values = require_child!(array.values(), 1 => AnyCanonical); + let values = Canonical::from(values); Ok(ExecutionStep::Done( - take_canonical(values, &codes, ctx)?.into_array(), + take_canonical(values, codes, ctx)?.into_array(), )) } @@ -231,27 +238,3 @@ impl VTable for DictVTable { PARENT_KERNELS.execute(array, parent, child_idx, ctx) } } - -/// Check for fast-path execution conditions. -pub(super) fn execute_fast_path( - array: &DictArray, - _ctx: &mut ExecutionCtx, -) -> VortexResult> { - // Empty array - nothing to do - if array.is_empty() { - let result_dtype = array - .dtype() - .union_nullability(array.codes().dtype().nullability()); - return Ok(Some(Canonical::empty(&result_dtype).into_array())); - } - - // All codes are null - result is all nulls - if array.codes.all_invalid()? { - return Ok(Some( - ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.codes.len()) - .into_array(), - )); - } - - Ok(None) -} diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs index 28d31c60d2b..62770ca695b 100644 --- a/vortex-array/src/arrays/filter/vtable.rs +++ b/vortex-array/src/arrays/filter/vtable.rs @@ -13,12 +13,16 @@ use vortex_error::vortex_panic; use vortex_mask::Mask; use vortex_session::VortexSession; +use crate::AnyCanonical; use crate::ArrayEq; use crate::ArrayHash; use crate::ArrayRef; +use crate::Canonical; use crate::DynArray; use crate::IntoArray; use crate::Precision; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; use crate::arrays::filter::array::FilterArray; use crate::arrays::filter::execute::execute_filter; use crate::arrays::filter::execute::execute_filter_fast_paths; @@ -156,18 +160,30 @@ impl VTable for FilterVTable { } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - if let Some(canonical) = execute_filter_fast_paths(array, ctx)? { - return Ok(ExecutionStep::Done(canonical)); + if let Some(result) = execute_filter_fast_paths(array, ctx)? { + return Ok(ExecutionStep::Done(result)); } let Mask::Values(mask_values) = &array.mask else { unreachable!("`execute_filter_fast_paths` handles AllTrue and AllFalse") }; - // We rely on the optimization pass that runs prior to this execution for filter pushdown, - // so now we can just execute the filter without worrying. - Ok(ExecutionStep::Done( - execute_filter(array.child.clone().execute(ctx)?, mask_values).into_array(), - )) + // If child is already canonical, filter it directly. + if let Some(canonical) = array.child.as_opt::() { + return Ok(ExecutionStep::Done( + execute_filter(Canonical::from(canonical), mask_values).into_array(), + )); + } + + // If child is a constant, filtering just changes the length. + if let Some(constant) = array.child.as_opt::() { + return Ok(ExecutionStep::Done( + ConstantArray::new(constant.scalar().clone(), mask_values.true_count()) + .into_array(), + )); + } + + // Otherwise, ask the scheduler to execute the child first. + Ok(ExecutionStep::execute_child::(0)) } fn reduce_parent( diff --git a/vortex-array/src/arrays/masked/vtable/mod.rs b/vortex-array/src/arrays/masked/vtable/mod.rs index dafd40a50ee..f84fdf561c8 100644 --- a/vortex-array/src/arrays/masked/vtable/mod.rs +++ b/vortex-array/src/arrays/masked/vtable/mod.rs @@ -13,13 +13,16 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_panic; use vortex_session::VortexSession; +use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; use crate::EmptyMetadata; use crate::IntoArray; use crate::Precision; use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; use crate::arrays::MaskedArray; +use crate::arrays::constant::constant_canonicalize; use crate::arrays::masked::compute::rules::PARENT_RULES; use crate::arrays::masked::mask_validity_canonical; use crate::buffer::BufferHandle; @@ -178,10 +181,24 @@ impl VTable for MaskedVTable { // While we could manually convert the dtype, `mask_validity_canonical` is already O(1) for // `AllTrue` masks (no data copying), so there's no benefit. - let child = array.child().clone().execute::(ctx)?; - Ok(ExecutionStep::Done( - mask_validity_canonical(child, &validity_mask, ctx)?.into_array(), - )) + // If child is already canonical, apply the validity mask directly. + if let Some(canonical) = array.child().as_opt::() { + return Ok(ExecutionStep::Done( + mask_validity_canonical(Canonical::from(canonical), &validity_mask, ctx)? + .into_array(), + )); + } + + // If child is a constant, expand to canonical then apply the validity mask. + if let Some(constant) = array.child().as_opt::() { + let canonical = constant_canonicalize(constant)?; + return Ok(ExecutionStep::Done( + mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array(), + )); + } + + // Otherwise, ask the scheduler to execute the child first. + Ok(ExecutionStep::execute_child::(0)) } fn reduce_parent( diff --git a/vortex-array/src/arrays/shared/vtable.rs b/vortex-array/src/arrays/shared/vtable.rs index 12106400338..9ddd4b6541c 100644 --- a/vortex-array/src/arrays/shared/vtable.rs +++ b/vortex-array/src/arrays/shared/vtable.rs @@ -8,12 +8,15 @@ use vortex_error::VortexResult; use vortex_error::vortex_panic; use vortex_session::VortexSession; +use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; use crate::EmptyMetadata; use crate::ExecutionCtx; use crate::ExecutionStep; +use crate::IntoArray; use crate::Precision; +use crate::arrays::ConstantVTable; use crate::arrays::SharedArray; use crate::buffer::BufferHandle; use crate::dtype::DType; @@ -145,10 +148,15 @@ impl VTable for SharedVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - array - .get_or_compute(|source| source.clone().execute::(ctx)) - .map(ExecutionStep::Done) + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + let current = array.current_array_ref(); + if let Some(canonical) = current.as_opt::() { + return Ok(ExecutionStep::Done(Canonical::from(canonical).into_array())); + } + if current.as_opt::().is_some() { + return Ok(ExecutionStep::Done(current.clone())); + } + Ok(ExecutionStep::execute_child::(0)) } } impl OperationsVTable for SharedVTable { diff --git a/vortex-array/src/arrays/slice/vtable.rs b/vortex-array/src/arrays/slice/vtable.rs index a8aaafb1d61..2a526f35896 100644 --- a/vortex-array/src/arrays/slice/vtable.rs +++ b/vortex-array/src/arrays/slice/vtable.rs @@ -20,7 +20,10 @@ use crate::ArrayHash; use crate::ArrayRef; use crate::Canonical; use crate::DynArray; +use crate::IntoArray; use crate::Precision; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; use crate::arrays::slice::array::SliceArray; use crate::arrays::slice::rules::PARENT_RULES; use crate::buffer::BufferHandle; @@ -155,23 +158,25 @@ impl VTable for SliceVTable { Ok(()) } - fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - // Execute the child to get canonical form, then slice it - let Some(canonical) = array.child.as_opt::() else { - // If the child is not canonical, recurse. - return array - .child - .clone() - .execute::(ctx)? - .slice(array.slice_range().clone()) + fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { + // If child is already canonical, slice it directly. + if let Some(canonical) = array.child.as_opt::() { + // TODO(ngates): we should inline canonical slice logic here. + return Canonical::from(canonical) + .as_ref() + .slice(array.range.clone()) .map(ExecutionStep::Done); - }; + } + + // If child is a constant, slicing just changes the length. + if let Some(constant) = array.child.as_opt::() { + return Ok(ExecutionStep::Done( + ConstantArray::new(constant.scalar().clone(), array.range.len()).into_array(), + )); + } - // TODO(ngates): we should inline canonical slice logic here. - Canonical::from(canonical) - .as_ref() - .slice(array.range.clone()) - .map(ExecutionStep::Done) + // Otherwise, ask the scheduler to execute the child first. + Ok(ExecutionStep::execute_child::(0)) } fn reduce_parent( diff --git a/vortex-array/src/executor.rs b/vortex-array/src/executor.rs index da05450f8de..1e6ae528090 100644 --- a/vortex-array/src/executor.rs +++ b/vortex-array/src/executor.rs @@ -387,6 +387,23 @@ impl fmt::Debug for ExecutionStep { } } +/// Require that a child array matches `$M`. If it does, evaluates to the matched value. +/// Otherwise, early-returns `Ok(ExecutionStep::execute_child::<$M>($idx))`. +/// +/// ```ignore +/// let codes = require_child!(array.codes(), 0 => PrimitiveVTable); +/// let values = require_child!(array.values(), 1 => AnyCanonical); +/// ``` +#[macro_export] +macro_rules! require_child { + ($child:expr, $idx:expr => $M:ty) => { + match $child.as_opt::<$M>() { + Some(c) => c, + None => return Ok($crate::ExecutionStep::execute_child::<$M>($idx)), + } + }; +} + /// Extension trait for creating an execution context from a session. pub trait VortexSessionExecute { /// Create a new execution context from this session.